wgpu_core/
validation.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString as _},
4    vec::Vec,
5};
6use core::fmt;
7
8use arrayvec::ArrayVec;
9use hashbrown::hash_map::Entry;
10use thiserror::Error;
11use wgt::{
12    error::{ErrorType, WebGpuError},
13    BindGroupLayoutEntry, BindingType,
14};
15
16use crate::{device::bgl, resource::InvalidResourceError, FastHashMap, FastHashSet};
17
/// Shader-side description of a bound resource, one variant per resource kind.
#[derive(Debug)]
enum ResourceType {
    /// A buffer binding.
    Buffer {
        /// Size compared against the layout entry's `min_binding_size` in
        /// `Resource::check_binding_use`.
        size: wgt::BufferSize,
    },
    /// An image binding, described by its naga dimension, arrayness and class.
    Texture {
        dim: naga::ImageDimension,
        arrayed: bool,
        class: naga::ImageClass,
    },
    /// A sampler binding.
    Sampler {
        /// `true` when the layout entry is expected to be a comparison sampler.
        comparison: bool,
    },
    /// An acceleration structure binding.
    AccelerationStructure {
        /// Has to equal the layout entry's `vertex_return` flag to validate.
        vertex_return: bool,
    },
}
35
/// Coarse name of a binding kind, reported by [`BindingError::WrongType`] for
/// both the shader side and the pipeline layout side.
#[derive(Clone, Debug)]
pub enum BindingTypeName {
    Buffer,
    Texture,
    Sampler,
    AccelerationStructure,
    ExternalTexture,
}
44
45impl From<&ResourceType> for BindingTypeName {
46    fn from(ty: &ResourceType) -> BindingTypeName {
47        match ty {
48            ResourceType::Buffer { .. } => BindingTypeName::Buffer,
49            ResourceType::Texture {
50                class: naga::ImageClass::External,
51                ..
52            } => BindingTypeName::ExternalTexture,
53            ResourceType::Texture { .. } => BindingTypeName::Texture,
54            ResourceType::Sampler { .. } => BindingTypeName::Sampler,
55            ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
56        }
57    }
58}
59
60impl From<&BindingType> for BindingTypeName {
61    fn from(ty: &BindingType) -> BindingTypeName {
62        match ty {
63            BindingType::Buffer { .. } => BindingTypeName::Buffer,
64            BindingType::Texture { .. } => BindingTypeName::Texture,
65            BindingType::StorageTexture { .. } => BindingTypeName::Texture,
66            BindingType::Sampler { .. } => BindingTypeName::Sampler,
67            BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
68            BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
69        }
70    }
71}
72
/// A resource as declared on the shader side, checked against bind group
/// layout entries by `check_binding_use`.
#[derive(Debug)]
struct Resource {
    // NOTE(review): not read by the code visible in this part of the file,
    // hence the lint allowance — confirm it is kept for debugging output only.
    #[allow(unused)]
    name: Option<String>,
    /// The naga resource binding of this resource.
    bind: naga::ResourceBinding,
    /// Kind-specific description of the resource.
    ty: ResourceType,
    /// Address space of the resource; compared against the one implied by a
    /// buffer layout entry.
    class: naga::AddressSpace,
}
81
/// Shape of a numeric value: a scalar, a vector, or a matrix.
#[derive(Clone, Copy, Debug)]
enum NumericDimension {
    Scalar,
    Vector(naga::VectorSize),
    /// The `Display` impl names the two sizes `columns` and `rows`, in that order.
    Matrix(naga::VectorSize, naga::VectorSize),
}
88
89impl fmt::Display for NumericDimension {
90    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
91        match *self {
92            Self::Scalar => write!(f, ""),
93            Self::Vector(size) => write!(f, "x{}", size as u8),
94            Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
95        }
96    }
97}
98
99impl NumericDimension {
100    fn num_components(&self) -> u32 {
101        match *self {
102            Self::Scalar => 1,
103            Self::Vector(size) => size as u32,
104            Self::Matrix(w, h) => w as u32 * h as u32,
105        }
106    }
107}
108
/// A numeric type: a naga scalar together with its dimension.
///
/// Displayed as `<kind><bits><dimension suffix>`.
#[derive(Clone, Copy, Debug)]
pub struct NumericType {
    dim: NumericDimension,
    scalar: naga::Scalar,
}
114
115impl fmt::Display for NumericType {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        write!(
118            f,
119            "{:?}{}{}",
120            self.scalar.kind,
121            self.scalar.width * 8,
122            self.dim
123        )
124    }
125}
126
/// An interface variable: its numeric type plus optional interpolation and
/// sampling qualifiers and a per-primitive flag.
#[derive(Clone, Debug)]
pub struct InterfaceVar {
    /// Numeric type of the variable.
    pub ty: NumericType,
    interpolation: Option<naga::Interpolation>,
    sampling: Option<naga::Sampling>,
    per_primitive: bool,
}
134
135impl InterfaceVar {
136    pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
137        InterfaceVar {
138            ty: NumericType::from_vertex_format(format),
139            interpolation: None,
140            sampling: None,
141            per_primitive: false,
142        }
143    }
144}
145
146impl fmt::Display for InterfaceVar {
147    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
148        write!(
149            f,
150            "{} interpolated as {:?} with sampling {:?}",
151            self.ty, self.interpolation, self.sampling
152        )
153    }
154}
155
/// An entry point input or output: either a user variable at a location or a
/// naga built-in.
#[derive(Debug)]
enum Varying {
    /// A user-defined variable bound to `location`.
    Local { location: u32, iv: InterfaceVar },
    /// A built-in value.
    BuiltIn(naga::BuiltIn),
}
161
/// A specialization constant: an id paired with a numeric type.
// NOTE(review): the fields are not read by the code visible here, hence the
// lint allowance — confirm against the rest of the file.
#[allow(unused)]
#[derive(Debug)]
struct SpecializationConstant {
    id: u32,
    ty: NumericType,
}
168
/// Mesh-shader output bounds recorded for an entry point.
#[derive(Debug)]
struct EntryPointMeshInfo {
    /// Maximum number of output vertices.
    max_vertices: u32,
    /// Maximum number of output primitives.
    max_primitives: u32,
}
174
/// Interface information recorded for a single shader entry point.
#[derive(Debug, Default)]
struct EntryPoint {
    /// Input varyings of the entry point.
    inputs: Vec<Varying>,
    /// Output varyings of the entry point.
    outputs: Vec<Varying>,
    /// Handles into the `resources` arena of the owning `Interface`.
    resources: Vec<naga::Handle<Resource>>,
    // NOTE(review): not read by the code visible here, hence the lint allowance.
    #[allow(unused)]
    spec_constants: Vec<SpecializationConstant>,
    /// Pairs of resource handles.
    // NOTE(review): presumably (texture, sampler) pairs sampled together — confirm
    // the element order against the code that fills this set.
    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
    workgroup_size: [u32; 3],
    dual_source_blending: bool,
    /// Task payload size, if the entry point has one.
    task_payload_size: Option<u32>,
    /// Mesh output bounds, if recorded for this entry point.
    mesh_info: Option<EntryPointMeshInfo>,
}
188
/// Shader interface data: the resource arena, the per-entry-point information
/// keyed by `(stage, entry point name)`, and a set of limits.
#[derive(Debug)]
pub struct Interface {
    limits: wgt::Limits,
    /// Arena that the `naga::Handle<Resource>` values in `EntryPoint` index into.
    resources: naga::Arena<Resource>,
    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
}
195
/// Reasons a shader resource can fail to match the pipeline layout.
///
/// Always classified as a validation error by its `WebGpuError` impl.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum BindingError {
    #[error("Binding is missing from the pipeline layout")]
    Missing,
    #[error("Visibility flags don't include the shader stage")]
    Invisible,
    /// The shader resource and the layout entry are different kinds of binding.
    #[error(
        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
    )]
    WrongType {
        binding: BindingTypeName,
        shader: BindingTypeName,
    },
    /// The address space implied by a buffer layout entry differs from the
    /// shader resource's address space.
    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
    WrongAddressSpace {
        binding: naga::AddressSpace,
        shader: naga::AddressSpace,
    },
    /// A buffer resource lives in an address space other than uniform or
    /// storage, so no buffer binding type can be derived for it.
    #[error("Address space {space:?} is not a valid Buffer address space")]
    WrongBufferAddressSpace { space: naga::AddressSpace },
    /// The layout entry's `min_binding_size` is smaller than the size the
    /// shader requires.
    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
    WrongBufferSize {
        buffer_size: wgt::BufferSize,
        min_binding_size: wgt::BufferSize,
    },
    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
    WrongTextureViewDimension {
        dim: naga::ImageDimension,
        is_array: bool,
        binding: BindingType,
    },
    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
    WrongTextureClass {
        binding: naga::ImageClass,
        shader: naga::ImageClass,
    },
    #[error("Comparison flag doesn't match the shader")]
    WrongSamplerComparison,
    #[error("Derived bind group layout type is not consistent between stages")]
    InconsistentlyDerivedType,
    /// The texture format has no naga storage format equivalent.
    #[error("Texture format {0:?} is not supported for storage use")]
    BadStorageFormat(wgt::TextureFormat),
}
240
impl WebGpuError for BindingError {
    /// Every binding error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
246
/// Reasons a texture cannot be sampled with a filtering sampler.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum FilteringError {
    #[error("Integer textures can't be sampled with a filtering sampler")]
    Integer,
    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
    Float,
}
255
impl WebGpuError for FilteringError {
    /// Every filtering error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
261
/// Reasons a stage input fails to match what the pipeline provides for it.
///
/// The payloads carry the *provided* value that the input was compared against.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum InputError {
    #[error("Input is not provided by the earlier stage in the pipeline")]
    Missing,
    #[error("Input type is not compatible with the provided {0}")]
    WrongType(NumericType),
    #[error("Input interpolation doesn't match provided {0:?}")]
    InterpolationMismatch(Option<naga::Interpolation>),
    #[error("Input sampling doesn't match provided {0:?}")]
    SamplingMismatch(Option<naga::Sampling>),
    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
    WrongPerPrimitive { pipeline_input: bool, shader: bool },
}
276
impl WebGpuError for InputError {
    /// Every input error is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
282
283/// Errors produced when validating a programmable stage of a pipeline.
284#[derive(Clone, Debug, Error)]
285#[non_exhaustive]
286pub enum StageError {
287    #[error(
288        "Shader entry point's workgroup size {current:?} ({current_total} total invocations) must be less or equal to the per-dimension
289        limit `Limits::{per_dimension_limit}` of {limit:?} and the total invocation limit `Limits::{total_limit}` of {total}"
290    )]
291    InvalidWorkgroupSize {
292        current: [u32; 3],
293        current_total: u32,
294        limit: [u32; 3],
295        total: u32,
296        per_dimension_limit: &'static str,
297        total_limit: &'static str,
298    },
299    #[error("Shader uses {used} inter-stage components above the limit of {limit}")]
300    TooManyVaryings { used: u32, limit: u32 },
301    #[error("Unable to find entry point '{0}'")]
302    MissingEntryPoint(String),
303    #[error("Shader global {0:?} is not available in the pipeline layout")]
304    Binding(naga::ResourceBinding, #[source] BindingError),
305    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
306    Filtering {
307        texture: naga::ResourceBinding,
308        sampler: naga::ResourceBinding,
309        #[source]
310        error: FilteringError,
311    },
312    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
313    Input {
314        location: wgt::ShaderLocation,
315        var: InterfaceVar,
316        #[source]
317        error: InputError,
318    },
319    #[error(
320        "Unable to select an entry point: no entry point was found in the provided shader module"
321    )]
322    NoEntryPointFound,
323    #[error(
324        "Unable to select an entry point: \
325        multiple entry points were found in the provided shader module, \
326        but no entry point was specified"
327    )]
328    MultipleEntryPointsFound,
329    #[error(transparent)]
330    InvalidResource(#[from] InvalidResourceError),
331    #[error(
332        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
333    )]
334    ColorAttachmentLocationTooLarge {
335        location: u32,
336        var: InterfaceVar,
337        limit: u32,
338    },
339    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
340    TooManyMeshVertices { limit: u32, value: u32 },
341    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
342    TooManyMeshPrimitives { limit: u32, value: u32 },
343    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
344    TaskPayloadTooLarge { limit: u32, value: u32 },
345    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
346    TaskPayloadMustMatch {
347        input: Option<u32>,
348        shader: Option<u32>,
349    },
350    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
351    InvalidPrimitiveIndex,
352    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
353    MissingPrimitiveIndex,
354    #[error("DrawId cannot be used in the same pipeline as a task shader")]
355    DrawIdError,
356}
357
358impl WebGpuError for StageError {
359    fn webgpu_error_type(&self) -> ErrorType {
360        let e: &dyn WebGpuError = match self {
361            Self::Binding(_, e) => e,
362            Self::InvalidResource(e) => e,
363            Self::Filtering {
364                texture: _,
365                sampler: _,
366                error,
367            } => error,
368            Self::Input {
369                location: _,
370                var: _,
371                error,
372            } => error,
373            Self::InvalidWorkgroupSize { .. }
374            | Self::TooManyVaryings { .. }
375            | Self::MissingEntryPoint(..)
376            | Self::NoEntryPointFound
377            | Self::MultipleEntryPointsFound
378            | Self::ColorAttachmentLocationTooLarge { .. }
379            | Self::TooManyMeshVertices { .. }
380            | Self::TooManyMeshPrimitives { .. }
381            | Self::TaskPayloadTooLarge { .. }
382            | Self::TaskPayloadMustMatch { .. }
383            | Self::InvalidPrimitiveIndex
384            | Self::MissingPrimitiveIndex
385            | Self::DrawIdError => return ErrorType::Validation,
386        };
387        e.webgpu_error_type()
388    }
389}
390
391pub fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> {
392    use naga::StorageFormat as Sf;
393    use wgt::TextureFormat as Tf;
394
395    Some(match format {
396        Tf::R8Unorm => Sf::R8Unorm,
397        Tf::R8Snorm => Sf::R8Snorm,
398        Tf::R8Uint => Sf::R8Uint,
399        Tf::R8Sint => Sf::R8Sint,
400
401        Tf::R16Uint => Sf::R16Uint,
402        Tf::R16Sint => Sf::R16Sint,
403        Tf::R16Float => Sf::R16Float,
404        Tf::Rg8Unorm => Sf::Rg8Unorm,
405        Tf::Rg8Snorm => Sf::Rg8Snorm,
406        Tf::Rg8Uint => Sf::Rg8Uint,
407        Tf::Rg8Sint => Sf::Rg8Sint,
408
409        Tf::R32Uint => Sf::R32Uint,
410        Tf::R32Sint => Sf::R32Sint,
411        Tf::R32Float => Sf::R32Float,
412        Tf::Rg16Uint => Sf::Rg16Uint,
413        Tf::Rg16Sint => Sf::Rg16Sint,
414        Tf::Rg16Float => Sf::Rg16Float,
415        Tf::Rgba8Unorm => Sf::Rgba8Unorm,
416        Tf::Rgba8Snorm => Sf::Rgba8Snorm,
417        Tf::Rgba8Uint => Sf::Rgba8Uint,
418        Tf::Rgba8Sint => Sf::Rgba8Sint,
419        Tf::Bgra8Unorm => Sf::Bgra8Unorm,
420
421        Tf::Rgb10a2Uint => Sf::Rgb10a2Uint,
422        Tf::Rgb10a2Unorm => Sf::Rgb10a2Unorm,
423        Tf::Rg11b10Ufloat => Sf::Rg11b10Ufloat,
424
425        Tf::R64Uint => Sf::R64Uint,
426        Tf::Rg32Uint => Sf::Rg32Uint,
427        Tf::Rg32Sint => Sf::Rg32Sint,
428        Tf::Rg32Float => Sf::Rg32Float,
429        Tf::Rgba16Uint => Sf::Rgba16Uint,
430        Tf::Rgba16Sint => Sf::Rgba16Sint,
431        Tf::Rgba16Float => Sf::Rgba16Float,
432
433        Tf::Rgba32Uint => Sf::Rgba32Uint,
434        Tf::Rgba32Sint => Sf::Rgba32Sint,
435        Tf::Rgba32Float => Sf::Rgba32Float,
436
437        Tf::R16Unorm => Sf::R16Unorm,
438        Tf::R16Snorm => Sf::R16Snorm,
439        Tf::Rg16Unorm => Sf::Rg16Unorm,
440        Tf::Rg16Snorm => Sf::Rg16Snorm,
441        Tf::Rgba16Unorm => Sf::Rgba16Unorm,
442        Tf::Rgba16Snorm => Sf::Rgba16Snorm,
443
444        _ => return None,
445    })
446}
447
/// Converts a naga storage format to the texture format of the same name.
///
/// The match has no wildcard arm, so it covers every `naga::StorageFormat`
/// variant and the conversion cannot fail.
pub fn map_storage_format_from_naga(format: naga::StorageFormat) -> wgt::TextureFormat {
    use naga::StorageFormat as Sf;
    use wgt::TextureFormat as Tf;

    match format {
        Sf::R8Unorm => Tf::R8Unorm,
        Sf::R8Snorm => Tf::R8Snorm,
        Sf::R8Uint => Tf::R8Uint,
        Sf::R8Sint => Tf::R8Sint,

        Sf::R16Uint => Tf::R16Uint,
        Sf::R16Sint => Tf::R16Sint,
        Sf::R16Float => Tf::R16Float,
        Sf::Rg8Unorm => Tf::Rg8Unorm,
        Sf::Rg8Snorm => Tf::Rg8Snorm,
        Sf::Rg8Uint => Tf::Rg8Uint,
        Sf::Rg8Sint => Tf::Rg8Sint,

        Sf::R32Uint => Tf::R32Uint,
        Sf::R32Sint => Tf::R32Sint,
        Sf::R32Float => Tf::R32Float,
        Sf::Rg16Uint => Tf::Rg16Uint,
        Sf::Rg16Sint => Tf::Rg16Sint,
        Sf::Rg16Float => Tf::Rg16Float,
        Sf::Rgba8Unorm => Tf::Rgba8Unorm,
        Sf::Rgba8Snorm => Tf::Rgba8Snorm,
        Sf::Rgba8Uint => Tf::Rgba8Uint,
        Sf::Rgba8Sint => Tf::Rgba8Sint,
        Sf::Bgra8Unorm => Tf::Bgra8Unorm,

        Sf::Rgb10a2Uint => Tf::Rgb10a2Uint,
        Sf::Rgb10a2Unorm => Tf::Rgb10a2Unorm,
        Sf::Rg11b10Ufloat => Tf::Rg11b10Ufloat,

        Sf::R64Uint => Tf::R64Uint,
        Sf::Rg32Uint => Tf::Rg32Uint,
        Sf::Rg32Sint => Tf::Rg32Sint,
        Sf::Rg32Float => Tf::Rg32Float,
        Sf::Rgba16Uint => Tf::Rgba16Uint,
        Sf::Rgba16Sint => Tf::Rgba16Sint,
        Sf::Rgba16Float => Tf::Rgba16Float,

        Sf::Rgba32Uint => Tf::Rgba32Uint,
        Sf::Rgba32Sint => Tf::Rgba32Sint,
        Sf::Rgba32Float => Tf::Rgba32Float,

        Sf::R16Unorm => Tf::R16Unorm,
        Sf::R16Snorm => Tf::R16Snorm,
        Sf::Rg16Unorm => Tf::Rg16Unorm,
        Sf::Rg16Snorm => Tf::Rg16Snorm,
        Sf::Rgba16Unorm => Tf::Rgba16Unorm,
        Sf::Rgba16Snorm => Tf::Rgba16Snorm,
    }
}
502
503impl Resource {
504    fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
505        match self.ty {
506            ResourceType::Buffer { size } => {
507                let min_size = match entry.ty {
508                    BindingType::Buffer {
509                        ty,
510                        has_dynamic_offset: _,
511                        min_binding_size,
512                    } => {
513                        let class = match ty {
514                            wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
515                            wgt::BufferBindingType::Storage { read_only } => {
516                                let mut naga_access = naga::StorageAccess::LOAD;
517                                naga_access.set(naga::StorageAccess::STORE, !read_only);
518                                naga::AddressSpace::Storage {
519                                    access: naga_access,
520                                }
521                            }
522                        };
523                        if self.class != class {
524                            return Err(BindingError::WrongAddressSpace {
525                                binding: class,
526                                shader: self.class,
527                            });
528                        }
529                        min_binding_size
530                    }
531                    _ => {
532                        return Err(BindingError::WrongType {
533                            binding: (&entry.ty).into(),
534                            shader: (&self.ty).into(),
535                        })
536                    }
537                };
538                match min_size {
539                    Some(non_zero) if non_zero < size => {
540                        return Err(BindingError::WrongBufferSize {
541                            buffer_size: size,
542                            min_binding_size: non_zero,
543                        })
544                    }
545                    _ => (),
546                }
547            }
548            ResourceType::Sampler { comparison } => match entry.ty {
549                BindingType::Sampler(ty) => {
550                    if (ty == wgt::SamplerBindingType::Comparison) != comparison {
551                        return Err(BindingError::WrongSamplerComparison);
552                    }
553                }
554                _ => {
555                    return Err(BindingError::WrongType {
556                        binding: (&entry.ty).into(),
557                        shader: (&self.ty).into(),
558                    })
559                }
560            },
561            ResourceType::Texture {
562                dim,
563                arrayed,
564                class,
565            } => {
566                let view_dimension = match entry.ty {
567                    BindingType::Texture { view_dimension, .. }
568                    | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
569                    BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
570                    _ => {
571                        return Err(BindingError::WrongTextureViewDimension {
572                            dim,
573                            is_array: false,
574                            binding: entry.ty,
575                        })
576                    }
577                };
578                if arrayed {
579                    match (dim, view_dimension) {
580                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
581                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
582                        _ => {
583                            return Err(BindingError::WrongTextureViewDimension {
584                                dim,
585                                is_array: true,
586                                binding: entry.ty,
587                            })
588                        }
589                    }
590                } else {
591                    match (dim, view_dimension) {
592                        (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
593                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
594                        (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
595                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
596                        _ => {
597                            return Err(BindingError::WrongTextureViewDimension {
598                                dim,
599                                is_array: false,
600                                binding: entry.ty,
601                            })
602                        }
603                    }
604                }
605                let expected_class = match entry.ty {
606                    BindingType::Texture {
607                        sample_type,
608                        view_dimension: _,
609                        multisampled: multi,
610                    } => match sample_type {
611                        wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
612                            kind: naga::ScalarKind::Float,
613                            multi,
614                        },
615                        wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
616                            kind: naga::ScalarKind::Sint,
617                            multi,
618                        },
619                        wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
620                            kind: naga::ScalarKind::Uint,
621                            multi,
622                        },
623                        wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
624                    },
625                    BindingType::StorageTexture {
626                        access,
627                        format,
628                        view_dimension: _,
629                    } => {
630                        let naga_format = map_storage_format_to_naga(format)
631                            .ok_or(BindingError::BadStorageFormat(format))?;
632                        let naga_access = match access {
633                            wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
634                            wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
635                            wgt::StorageTextureAccess::ReadWrite => {
636                                naga::StorageAccess::LOAD | naga::StorageAccess::STORE
637                            }
638                            wgt::StorageTextureAccess::Atomic => {
639                                naga::StorageAccess::ATOMIC
640                                    | naga::StorageAccess::LOAD
641                                    | naga::StorageAccess::STORE
642                            }
643                        };
644                        naga::ImageClass::Storage {
645                            format: naga_format,
646                            access: naga_access,
647                        }
648                    }
649                    BindingType::ExternalTexture => naga::ImageClass::External,
650                    _ => {
651                        return Err(BindingError::WrongType {
652                            binding: (&entry.ty).into(),
653                            shader: (&self.ty).into(),
654                        })
655                    }
656                };
657                if class != expected_class {
658                    return Err(BindingError::WrongTextureClass {
659                        binding: expected_class,
660                        shader: class,
661                    });
662                }
663            }
664            ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
665                BindingType::AccelerationStructure {
666                    vertex_return: entry_vertex_return,
667                } if vertex_return == entry_vertex_return => (),
668                _ => {
669                    return Err(BindingError::WrongType {
670                        binding: (&entry.ty).into(),
671                        shader: (&self.ty).into(),
672                    })
673                }
674            },
675        };
676
677        Ok(())
678    }
679
680    fn derive_binding_type(
681        &self,
682        is_reffed_by_sampler_in_entrypoint: bool,
683    ) -> Result<BindingType, BindingError> {
684        Ok(match self.ty {
685            ResourceType::Buffer { size } => BindingType::Buffer {
686                ty: match self.class {
687                    naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
688                    naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
689                        read_only: access == naga::StorageAccess::LOAD,
690                    },
691                    _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
692                },
693                has_dynamic_offset: false,
694                min_binding_size: Some(size),
695            },
696            ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
697                wgt::SamplerBindingType::Comparison
698            } else {
699                wgt::SamplerBindingType::Filtering
700            }),
701            ResourceType::Texture {
702                dim,
703                arrayed,
704                class,
705            } => {
706                let view_dimension = match dim {
707                    naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
708                    naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
709                    naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
710                    naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
711                    naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
712                    naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
713                };
714                match class {
715                    naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
716                        sample_type: match kind {
717                            naga::ScalarKind::Float => wgt::TextureSampleType::Float {
718                                filterable: is_reffed_by_sampler_in_entrypoint,
719                            },
720                            naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
721                            naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
722                            naga::ScalarKind::AbstractInt
723                            | naga::ScalarKind::AbstractFloat
724                            | naga::ScalarKind::Bool => unreachable!(),
725                        },
726                        view_dimension,
727                        multisampled: multi,
728                    },
729                    naga::ImageClass::Depth { multi } => BindingType::Texture {
730                        sample_type: wgt::TextureSampleType::Depth,
731                        view_dimension,
732                        multisampled: multi,
733                    },
734                    naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
735                        access: {
736                            const LOAD_STORE: naga::StorageAccess =
737                                naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
738                            match access {
739                                naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
740                                naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
741                                LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
742                                _ if access.contains(naga::StorageAccess::ATOMIC) => {
743                                    wgt::StorageTextureAccess::Atomic
744                                }
745                                _ => unreachable!(),
746                            }
747                        },
748                        view_dimension,
749                        format: {
750                            let f = map_storage_format_from_naga(format);
751                            let original = map_storage_format_to_naga(f)
752                                .ok_or(BindingError::BadStorageFormat(f))?;
753                            debug_assert_eq!(format, original);
754                            f
755                        },
756                    },
757                    naga::ImageClass::External => BindingType::ExternalTexture,
758                }
759            }
760            ResourceType::AccelerationStructure { vertex_return } => {
761                BindingType::AccelerationStructure { vertex_return }
762            }
763        })
764    }
765}
766
767impl NumericType {
768    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
769        use naga::{Scalar, VectorSize as Vs};
770        use wgt::VertexFormat as Vf;
771
772        let (dim, scalar) = match format {
773            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
774            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
775                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
776            }
777            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
778            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
779                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
780            }
781            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
782            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
783                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
784            }
785            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
786            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
787                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
788            }
789            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
790                (NumericDimension::Scalar, Scalar::F32)
791            }
792            Vf::Unorm8x2
793            | Vf::Snorm8x2
794            | Vf::Unorm16x2
795            | Vf::Snorm16x2
796            | Vf::Float16x2
797            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
798            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
799            Vf::Unorm8x4
800            | Vf::Snorm8x4
801            | Vf::Unorm16x4
802            | Vf::Snorm16x4
803            | Vf::Float16x4
804            | Vf::Float32x4
805            | Vf::Unorm10_10_10_2
806            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
807            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
808            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
809            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
810            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
811        };
812
813        NumericType {
814            dim,
815            //Note: Shader always sees data as int, uint, or float.
816            // It doesn't know if the original is normalized in a tighter form.
817            scalar,
818        }
819    }
820
821    fn from_texture_format(format: wgt::TextureFormat) -> Self {
822        use naga::{Scalar, VectorSize as Vs};
823        use wgt::TextureFormat as Tf;
824
825        let (dim, scalar) = match format {
826            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
827                (NumericDimension::Scalar, Scalar::F32)
828            }
829            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
830            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
831            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
832                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
833            }
834            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
835            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
836                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
837            }
838            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
839                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
840            }
841            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
842            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
843            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
844            Tf::Rgba8Unorm
845            | Tf::Rgba8UnormSrgb
846            | Tf::Rgba8Snorm
847            | Tf::Bgra8Unorm
848            | Tf::Bgra8UnormSrgb
849            | Tf::Rgb10a2Unorm
850            | Tf::Rgba16Float
851            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
852            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
853                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
854            }
855            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
856                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
857            }
858            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
859            Tf::Stencil8
860            | Tf::Depth16Unorm
861            | Tf::Depth32Float
862            | Tf::Depth32FloatStencil8
863            | Tf::Depth24Plus
864            | Tf::Depth24PlusStencil8 => {
865                panic!("Unexpected depth format")
866            }
867            Tf::NV12 => panic!("Unexpected nv12 format"),
868            Tf::P010 => panic!("Unexpected p010 format"),
869            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
870            Tf::Bc1RgbaUnorm
871            | Tf::Bc1RgbaUnormSrgb
872            | Tf::Bc2RgbaUnorm
873            | Tf::Bc2RgbaUnormSrgb
874            | Tf::Bc3RgbaUnorm
875            | Tf::Bc3RgbaUnormSrgb
876            | Tf::Bc7RgbaUnorm
877            | Tf::Bc7RgbaUnormSrgb
878            | Tf::Etc2Rgb8A1Unorm
879            | Tf::Etc2Rgb8A1UnormSrgb
880            | Tf::Etc2Rgba8Unorm
881            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
882            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
883                (NumericDimension::Scalar, Scalar::F32)
884            }
885            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
886                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
887            }
888            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
889                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
890            }
891            Tf::Astc {
892                block: _,
893                channel: _,
894            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
895        };
896
897        NumericType {
898            dim,
899            //Note: Shader always sees data as int, uint, or float.
900            // It doesn't know if the original is normalized in a tighter form.
901            scalar,
902        }
903    }
904
905    fn is_subtype_of(&self, other: &NumericType) -> bool {
906        if self.scalar.width > other.scalar.width {
907            return false;
908        }
909        if self.scalar.kind != other.scalar.kind {
910            return false;
911        }
912        match (self.dim, other.dim) {
913            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
914            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
915            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
916            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
917                c0 == c1 && r0 == r1
918            }
919            _ => false,
920        }
921    }
922}
923
924/// Return true if the fragment `format` is covered by the provided `output`.
925pub fn check_texture_format(
926    format: wgt::TextureFormat,
927    output: &NumericType,
928) -> Result<(), NumericType> {
929    let nt = NumericType::from_texture_format(format);
930    if nt.is_subtype_of(output) {
931        Ok(())
932    } else {
933        Err(nt)
934    }
935}
936
/// Where the per-bind-group layout entry maps used by shader binding
/// validation come from.
///
/// Both variants hold at most `hal::MAX_BIND_GROUPS` entry maps: `Derived`
/// owns them (boxed), `Provided` borrows them for `'a`.
pub enum BindingLayoutSource<'a> {
    /// The binding layout is derived from the pipeline layout.
    ///
    /// This will be filled in by the shader binding validation, as it iterates the shader's interfaces.
    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
    /// The binding layout is provided by the user in BGLs.
    ///
    /// This will be validated against the shader's interfaces.
    Provided(ArrayVec<&'a bgl::EntryMap, { hal::MAX_BIND_GROUPS }>),
}
947
948impl<'a> BindingLayoutSource<'a> {
949    pub fn new_derived(limits: &wgt::Limits) -> Self {
950        let mut array = ArrayVec::new();
951        for _ in 0..limits.max_bind_groups {
952            array.push(Default::default());
953        }
954        BindingLayoutSource::Derived(Box::new(array))
955    }
956}
957
/// Interface data exchanged between consecutive shader stages during
/// validation.
///
/// `Interface::check_stage` receives one of these as the inputs of the stage
/// being checked and returns another one on success.
#[derive(Debug, Clone, Default)]
pub struct StageIo {
    /// Interface variables, keyed by their shader location.
    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
    /// This must match between mesh & task shaders
    pub task_payload_size: Option<u32>,
    /// Fragment shaders cannot input primitive index on mesh shaders that don't output it on DX12.
    /// Therefore, we track between shader stages if primitive index is written (or if vertex shader
    /// is used).
    ///
    /// This is Some if it was a mesh shader.
    pub primitive_index: Option<bool>,
}
970
971impl Interface {
972    fn populate(
973        list: &mut Vec<Varying>,
974        binding: Option<&naga::Binding>,
975        ty: naga::Handle<naga::Type>,
976        arena: &naga::UniqueArena<naga::Type>,
977    ) {
978        let numeric_ty = match arena[ty].inner {
979            naga::TypeInner::Scalar(scalar) => NumericType {
980                dim: NumericDimension::Scalar,
981                scalar,
982            },
983            naga::TypeInner::Vector { size, scalar } => NumericType {
984                dim: NumericDimension::Vector(size),
985                scalar,
986            },
987            naga::TypeInner::Matrix {
988                columns,
989                rows,
990                scalar,
991            } => NumericType {
992                dim: NumericDimension::Matrix(columns, rows),
993                scalar,
994            },
995            naga::TypeInner::Struct { ref members, .. } => {
996                for member in members {
997                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
998                }
999                return;
1000            }
1001            ref other => {
1002                //Note: technically this should be at least `log::error`, but
1003                // the reality is - every shader coming from `glslc` outputs an array
1004                // of clip distances and hits this path :(
1005                // So we lower it to `log::debug` to be less annoying as
1006                // there's nothing the user can do about it.
1007                log::debug!("Unexpected varying type: {other:?}");
1008                return;
1009            }
1010        };
1011
1012        let varying = match binding {
1013            Some(&naga::Binding::Location {
1014                location,
1015                interpolation,
1016                sampling,
1017                per_primitive,
1018                blend_src: _,
1019            }) => Varying::Local {
1020                location,
1021                iv: InterfaceVar {
1022                    ty: numeric_ty,
1023                    interpolation,
1024                    sampling,
1025                    per_primitive,
1026                },
1027            },
1028            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(built_in),
1029            None => {
1030                log::error!("Missing binding for a varying");
1031                return;
1032            }
1033        };
1034        list.push(varying);
1035    }
1036
1037    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
1038        let mut resources = naga::Arena::new();
1039        let mut resource_mapping = FastHashMap::default();
1040        for (var_handle, var) in module.global_variables.iter() {
1041            let bind = match var.binding {
1042                Some(br) => br,
1043                _ => continue,
1044            };
1045            let naga_ty = &module.types[var.ty].inner;
1046
1047            let inner_ty = match *naga_ty {
1048                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
1049                ref ty => ty,
1050            };
1051
1052            let ty = match *inner_ty {
1053                naga::TypeInner::Image {
1054                    dim,
1055                    arrayed,
1056                    class,
1057                } => ResourceType::Texture {
1058                    dim,
1059                    arrayed,
1060                    class,
1061                },
1062                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
1063                naga::TypeInner::AccelerationStructure { vertex_return } => {
1064                    ResourceType::AccelerationStructure { vertex_return }
1065                }
1066                ref other => ResourceType::Buffer {
1067                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
1068                },
1069            };
1070            let handle = resources.append(
1071                Resource {
1072                    name: var.name.clone(),
1073                    bind,
1074                    ty,
1075                    class: var.space,
1076                },
1077                Default::default(),
1078            );
1079            resource_mapping.insert(var_handle, handle);
1080        }
1081
1082        let mut entry_points = FastHashMap::default();
1083        entry_points.reserve(module.entry_points.len());
1084        for (index, entry_point) in module.entry_points.iter().enumerate() {
1085            let info = info.get_entry_point(index);
1086            let mut ep = EntryPoint::default();
1087            for arg in entry_point.function.arguments.iter() {
1088                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
1089            }
1090            if let Some(ref result) = entry_point.function.result {
1091                Self::populate(
1092                    &mut ep.outputs,
1093                    result.binding.as_ref(),
1094                    result.ty,
1095                    &module.types,
1096                );
1097            }
1098
1099            for (var_handle, var) in module.global_variables.iter() {
1100                let usage = info[var_handle];
1101                if !usage.is_empty() && var.binding.is_some() {
1102                    ep.resources.push(resource_mapping[&var_handle]);
1103                }
1104            }
1105
1106            for key in info.sampling_set.iter() {
1107                ep.sampling_pairs
1108                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
1109            }
1110            ep.dual_source_blending = info.dual_source_blending;
1111            ep.workgroup_size = entry_point.workgroup_size;
1112
1113            if let Some(task_payload) = entry_point.task_payload {
1114                ep.task_payload_size = Some(
1115                    module.types[module.global_variables[task_payload].ty]
1116                        .inner
1117                        .size(module.to_ctx()),
1118                );
1119            }
1120            if let Some(ref mesh_info) = entry_point.mesh_info {
1121                ep.mesh_info = Some(EntryPointMeshInfo {
1122                    max_vertices: mesh_info.max_vertices,
1123                    max_primitives: mesh_info.max_primitives,
1124                });
1125                Self::populate(
1126                    &mut ep.outputs,
1127                    None,
1128                    mesh_info.vertex_output_type,
1129                    &module.types,
1130                );
1131                Self::populate(
1132                    &mut ep.outputs,
1133                    None,
1134                    mesh_info.primitive_output_type,
1135                    &module.types,
1136                );
1137            }
1138
1139            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
1140        }
1141
1142        Self {
1143            limits,
1144            resources,
1145            entry_points,
1146        }
1147    }
1148
1149    pub fn finalize_entry_point_name(
1150        &self,
1151        stage_bit: wgt::ShaderStages,
1152        entry_point_name: Option<&str>,
1153    ) -> Result<String, StageError> {
1154        let stage = Self::shader_stage_from_stage_bit(stage_bit);
1155        entry_point_name
1156            .map(|ep| ep.to_string())
1157            .map(Ok)
1158            .unwrap_or_else(|| {
1159                let mut entry_points = self
1160                    .entry_points
1161                    .keys()
1162                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
1163                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1164                if entry_points.next().is_some() {
1165                    return Err(StageError::MultipleEntryPointsFound);
1166                }
1167                Ok(first.clone())
1168            })
1169    }
1170
1171    pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage {
1172        match stage_bit {
1173            wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex,
1174            wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment,
1175            wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute,
1176            wgt::ShaderStages::MESH => naga::ShaderStage::Mesh,
1177            wgt::ShaderStages::TASK => naga::ShaderStage::Task,
1178            _ => unreachable!(),
1179        }
1180    }
1181
1182    pub fn check_stage(
1183        &self,
1184        layouts: &mut BindingLayoutSource<'_>,
1185        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
1186        entry_point_name: &str,
1187        stage_bit: wgt::ShaderStages,
1188        inputs: StageIo,
1189        compare_function: Option<wgt::CompareFunction>,
1190    ) -> Result<StageIo, StageError> {
1191        // Since a shader module can have multiple entry points with the same name,
1192        // we need to look for one with the right execution model.
1193        let shader_stage = Self::shader_stage_from_stage_bit(stage_bit);
1194        let pair = (shader_stage, entry_point_name.to_string());
1195        let entry_point = match self.entry_points.get(&pair) {
1196            Some(some) => some,
1197            None => return Err(StageError::MissingEntryPoint(pair.1)),
1198        };
1199        let (_, entry_point_name) = pair;
1200
1201        // check resources visibility
1202        for &handle in entry_point.resources.iter() {
1203            let res = &self.resources[handle];
1204            let result = 'err: {
1205                match layouts {
1206                    BindingLayoutSource::Provided(layouts) => {
1207                        // update the required binding size for this buffer
1208                        if let ResourceType::Buffer { size } = res.ty {
1209                            match shader_binding_sizes.entry(res.bind) {
1210                                Entry::Occupied(e) => {
1211                                    *e.into_mut() = size.max(*e.get());
1212                                }
1213                                Entry::Vacant(e) => {
1214                                    e.insert(size);
1215                                }
1216                            }
1217                        }
1218
1219                        let Some(map) = layouts.get(res.bind.group as usize) else {
1220                            break 'err Err(BindingError::Missing);
1221                        };
1222
1223                        let Some(entry) = map.get(res.bind.binding) else {
1224                            break 'err Err(BindingError::Missing);
1225                        };
1226
1227                        if !entry.visibility.contains(stage_bit) {
1228                            break 'err Err(BindingError::Invisible);
1229                        }
1230
1231                        res.check_binding_use(entry)
1232                    }
1233                    BindingLayoutSource::Derived(layouts) => {
1234                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
1235                            break 'err Err(BindingError::Missing);
1236                        };
1237
1238                        let ty = match res.derive_binding_type(
1239                            entry_point
1240                                .sampling_pairs
1241                                .iter()
1242                                .any(|&(im, _samp)| im == handle),
1243                        ) {
1244                            Ok(ty) => ty,
1245                            Err(error) => break 'err Err(error),
1246                        };
1247
1248                        match map.entry(res.bind.binding) {
1249                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
1250                                break 'err Err(BindingError::InconsistentlyDerivedType)
1251                            }
1252                            indexmap::map::Entry::Occupied(e) => {
1253                                e.into_mut().visibility |= stage_bit;
1254                            }
1255                            indexmap::map::Entry::Vacant(e) => {
1256                                e.insert(BindGroupLayoutEntry {
1257                                    binding: res.bind.binding,
1258                                    ty,
1259                                    visibility: stage_bit,
1260                                    count: None,
1261                                });
1262                            }
1263                        }
1264                        Ok(())
1265                    }
1266                }
1267            };
1268            if let Err(error) = result {
1269                return Err(StageError::Binding(res.bind, error));
1270            }
1271        }
1272
1273        // Check the compatibility between textures and samplers
1274        //
1275        // We only need to do this if the binding layout is provided by the user, as derived
1276        // layouts will inherently be correctly tagged.
1277        if let BindingLayoutSource::Provided(layouts) = layouts {
1278            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
1279                let texture_bind = &self.resources[texture_handle].bind;
1280                let sampler_bind = &self.resources[sampler_handle].bind;
1281                let texture_layout = layouts[texture_bind.group as usize]
1282                    .get(texture_bind.binding)
1283                    .unwrap();
1284                let sampler_layout = layouts[sampler_bind.group as usize]
1285                    .get(sampler_bind.binding)
1286                    .unwrap();
1287                assert!(texture_layout.visibility.contains(stage_bit));
1288                assert!(sampler_layout.visibility.contains(stage_bit));
1289
1290                let sampler_filtering = matches!(
1291                    sampler_layout.ty,
1292                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
1293                );
1294                let texture_sample_type = match texture_layout.ty {
1295                    BindingType::Texture { sample_type, .. } => sample_type,
1296                    BindingType::ExternalTexture => {
1297                        wgt::TextureSampleType::Float { filterable: true }
1298                    }
1299                    _ => unreachable!(),
1300                };
1301
1302                let error = match (sampler_filtering, texture_sample_type) {
1303                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
1304                        Some(FilteringError::Float)
1305                    }
1306                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
1307                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
1308                    _ => None,
1309                };
1310
1311                if let Some(error) = error {
1312                    return Err(StageError::Filtering {
1313                        texture: *texture_bind,
1314                        sampler: *sampler_bind,
1315                        error,
1316                    });
1317                }
1318            }
1319        }
1320
1321        // check workgroup size limits
1322        if shader_stage.compute_like() {
1323            let (
1324                max_workgroup_size_limits,
1325                max_workgroup_size_total,
1326                per_dimension_limit,
1327                total_limit,
1328            ) = match shader_stage {
1329                naga::ShaderStage::Compute => (
1330                    [
1331                        self.limits.max_compute_workgroup_size_x,
1332                        self.limits.max_compute_workgroup_size_y,
1333                        self.limits.max_compute_workgroup_size_z,
1334                    ],
1335                    self.limits.max_compute_invocations_per_workgroup,
1336                    "max_compute_workgroup_size_*",
1337                    "max_compute_invocations_per_workgroup",
1338                ),
1339                naga::ShaderStage::Task => (
1340                    [
1341                        self.limits.max_task_invocations_per_dimension,
1342                        self.limits.max_task_invocations_per_dimension,
1343                        self.limits.max_task_invocations_per_dimension,
1344                    ],
1345                    self.limits.max_task_invocations_per_workgroup,
1346                    "max_task_invocations_per_dimension",
1347                    "max_task_invocations_per_workgroup",
1348                ),
1349                naga::ShaderStage::Mesh => (
1350                    [
1351                        self.limits.max_mesh_invocations_per_dimension,
1352                        self.limits.max_mesh_invocations_per_dimension,
1353                        self.limits.max_mesh_invocations_per_dimension,
1354                    ],
1355                    self.limits.max_mesh_invocations_per_workgroup,
1356                    "max_mesh_invocations_per_dimension",
1357                    "max_mesh_invocations_per_workgroup",
1358                ),
1359                _ => unreachable!(),
1360            };
1361            let total_invocations = entry_point.workgroup_size.iter().product::<u32>();
1362
1363            let workgroup_size_is_zero = entry_point.workgroup_size.contains(&0);
1364            let too_many_invocations = total_invocations > max_workgroup_size_total;
1365            let dimension_too_large = entry_point.workgroup_size[0] > max_workgroup_size_limits[0]
1366                || entry_point.workgroup_size[1] > max_workgroup_size_limits[1]
1367                || entry_point.workgroup_size[2] > max_workgroup_size_limits[2];
1368            if workgroup_size_is_zero || too_many_invocations || dimension_too_large {
1369                return Err(StageError::InvalidWorkgroupSize {
1370                    current: entry_point.workgroup_size,
1371                    current_total: total_invocations,
1372                    limit: max_workgroup_size_limits,
1373                    total: max_workgroup_size_total,
1374                    per_dimension_limit,
1375                    total_limit,
1376                });
1377            }
1378        }
1379
1380        let mut inter_stage_components = 0;
1381        let mut this_stage_primitive_index = false;
1382        let mut has_draw_id = false;
1383
1384        // check inputs compatibility
1385        for input in entry_point.inputs.iter() {
1386            match *input {
1387                Varying::Local { location, ref iv } => {
1388                    let result = inputs
1389                        .varyings
1390                        .get(&location)
1391                        .ok_or(InputError::Missing)
1392                        .and_then(|provided| {
1393                            let (compatible, num_components, per_primitive_correct) =
1394                                match shader_stage {
1395                                    // For vertex attributes, there are defaults filled out
1396                                    // by the driver if data is not provided.
1397                                    naga::ShaderStage::Vertex => {
1398                                        let is_compatible =
1399                                            iv.ty.scalar.kind == provided.ty.scalar.kind;
1400                                        // vertex inputs don't count towards inter-stage
1401                                        (is_compatible, 0, !iv.per_primitive)
1402                                    }
1403                                    naga::ShaderStage::Fragment => {
1404                                        if iv.interpolation != provided.interpolation {
1405                                            return Err(InputError::InterpolationMismatch(
1406                                                provided.interpolation,
1407                                            ));
1408                                        }
1409                                        if iv.sampling != provided.sampling {
1410                                            return Err(InputError::SamplingMismatch(
1411                                                provided.sampling,
1412                                            ));
1413                                        }
1414                                        (
1415                                            iv.ty.is_subtype_of(&provided.ty),
1416                                            iv.ty.dim.num_components(),
1417                                            iv.per_primitive == provided.per_primitive,
1418                                        )
1419                                    }
1420                                    // These can't have varying inputs
1421                                    naga::ShaderStage::Compute
1422                                    | naga::ShaderStage::Task
1423                                    | naga::ShaderStage::Mesh => (false, 0, false),
1424                                };
1425                            if !compatible {
1426                                return Err(InputError::WrongType(provided.ty));
1427                            } else if !per_primitive_correct {
1428                                return Err(InputError::WrongPerPrimitive {
1429                                    pipeline_input: provided.per_primitive,
1430                                    shader: iv.per_primitive,
1431                                });
1432                            }
1433                            Ok(num_components)
1434                        });
1435                    match result {
1436                        Ok(num_components) => {
1437                            inter_stage_components += num_components;
1438                        }
1439                        Err(error) => {
1440                            return Err(StageError::Input {
1441                                location,
1442                                var: iv.clone(),
1443                                error,
1444                            })
1445                        }
1446                    }
1447                }
1448                Varying::BuiltIn(naga::BuiltIn::PrimitiveIndex) => {
1449                    this_stage_primitive_index = true;
1450                }
1451                Varying::BuiltIn(naga::BuiltIn::DrawID) => {
1452                    has_draw_id = true;
1453                }
1454                Varying::BuiltIn(_) => {}
1455            }
1456        }
1457
1458        match shader_stage {
1459            naga::ShaderStage::Vertex => {
1460                for output in entry_point.outputs.iter() {
1461                    //TODO: count builtins towards the limit?
1462                    inter_stage_components += match *output {
1463                        Varying::Local { ref iv, .. } => iv.ty.dim.num_components(),
1464                        Varying::BuiltIn(_) => 0,
1465                    };
1466
1467                    if let Some(
1468                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
1469                    ) = compare_function
1470                    {
1471                        if let Varying::BuiltIn(naga::BuiltIn::Position { invariant: false }) =
1472                            *output
1473                        {
1474                            log::warn!(
1475                                concat!(
1476                                    "Vertex shader with entry point {} outputs a ",
1477                                    "@builtin(position) without the @invariant attribute and ",
1478                                    "is used in a pipeline with {cmp:?}. On some machines, ",
1479                                    "this can cause bad artifacting as {cmp:?} assumes the ",
1480                                    "values output from the vertex shader exactly match the ",
1481                                    "value in the depth buffer. The @invariant attribute on the ",
1482                                    "@builtin(position) vertex output ensures that the exact ",
1483                                    "same pixel depths are used every render."
1484                                ),
1485                                entry_point_name,
1486                                cmp = cmp
1487                            );
1488                        }
1489                    }
1490                }
1491            }
1492            naga::ShaderStage::Fragment => {
1493                for output in &entry_point.outputs {
1494                    let &Varying::Local { location, ref iv } = output else {
1495                        continue;
1496                    };
1497                    if location >= self.limits.max_color_attachments {
1498                        return Err(StageError::ColorAttachmentLocationTooLarge {
1499                            location,
1500                            var: iv.clone(),
1501                            limit: self.limits.max_color_attachments,
1502                        });
1503                    }
1504                }
1505            }
1506            _ => (),
1507        }
1508
1509        if inter_stage_components > self.limits.max_inter_stage_shader_components {
1510            return Err(StageError::TooManyVaryings {
1511                used: inter_stage_components,
1512                limit: self.limits.max_inter_stage_shader_components,
1513            });
1514        }
1515
1516        if let Some(ref mesh_info) = entry_point.mesh_info {
1517            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
1518                return Err(StageError::TooManyMeshVertices {
1519                    limit: self.limits.max_mesh_output_vertices,
1520                    value: mesh_info.max_vertices,
1521                });
1522            }
1523            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
1524                return Err(StageError::TooManyMeshPrimitives {
1525                    limit: self.limits.max_mesh_output_primitives,
1526                    value: mesh_info.max_primitives,
1527                });
1528            }
1529        }
1530        if let Some(task_payload_size) = entry_point.task_payload_size {
1531            if task_payload_size > self.limits.max_task_payload_size {
1532                return Err(StageError::TaskPayloadTooLarge {
1533                    limit: self.limits.max_task_payload_size,
1534                    value: task_payload_size,
1535                });
1536            }
1537        }
1538        if shader_stage == naga::ShaderStage::Mesh
1539            && entry_point.task_payload_size != inputs.task_payload_size
1540        {
1541            return Err(StageError::TaskPayloadMustMatch {
1542                input: inputs.task_payload_size,
1543                shader: entry_point.task_payload_size,
1544            });
1545        }
1546
1547        // Fragment shader primitive index is treated like a varying
1548        if shader_stage == naga::ShaderStage::Fragment
1549            && this_stage_primitive_index
1550            && inputs.primitive_index == Some(false)
1551        {
1552            return Err(StageError::InvalidPrimitiveIndex);
1553        } else if shader_stage == naga::ShaderStage::Fragment
1554            && !this_stage_primitive_index
1555            && inputs.primitive_index == Some(true)
1556        {
1557            return Err(StageError::MissingPrimitiveIndex);
1558        }
1559        if shader_stage == naga::ShaderStage::Mesh
1560            && inputs.task_payload_size.is_some()
1561            && has_draw_id
1562        {
1563            return Err(StageError::DrawIdError);
1564        }
1565
1566        let outputs = entry_point
1567            .outputs
1568            .iter()
1569            .filter_map(|output| match *output {
1570                Varying::Local { location, ref iv } => Some((location, iv.clone())),
1571                Varying::BuiltIn(_) => None,
1572            })
1573            .collect();
1574
1575        Ok(StageIo {
1576            task_payload_size: entry_point.task_payload_size,
1577            varyings: outputs,
1578            primitive_index: if shader_stage == naga::ShaderStage::Mesh {
1579                Some(this_stage_primitive_index)
1580            } else {
1581                None
1582            },
1583        })
1584    }
1585
1586    pub fn fragment_uses_dual_source_blending(
1587        &self,
1588        entry_point_name: &str,
1589    ) -> Result<bool, StageError> {
1590        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
1591        self.entry_points
1592            .get(&pair)
1593            .ok_or(StageError::MissingEntryPoint(pair.1))
1594            .map(|ep| ep.dual_source_blending)
1595    }
1596}
1597
1598/// Validate a list of color attachment formats against `maxColorAttachmentBytesPerSample`.
1599///
1600/// The color attachments can be from a render pass descriptor or a pipeline descriptor.
1601///
1602/// Implements <https://gpuweb.github.io/gpuweb/#abstract-opdef-calculating-color-attachment-bytes-per-sample>.
1603pub fn validate_color_attachment_bytes_per_sample(
1604    attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1605    limit: u32,
1606) -> Result<(), crate::command::ColorAttachmentError> {
1607    let mut total_bytes_per_sample: u32 = 0;
1608    for format in attachment_formats {
1609        let byte_cost = format.target_pixel_byte_cost().unwrap();
1610        let alignment = format.target_component_alignment().unwrap();
1611
1612        total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1613        total_bytes_per_sample += byte_cost;
1614    }
1615
1616    if total_bytes_per_sample > limit {
1617        return Err(
1618            crate::command::ColorAttachmentError::TooManyBytesPerSample {
1619                total: total_bytes_per_sample,
1620                limit,
1621            },
1622        );
1623    }
1624
1625    Ok(())
1626}