Skip to main content

wgpu_core/
validation.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString as _},
4    sync::Arc,
5    vec::Vec,
6};
7use core::fmt;
8
9use arrayvec::ArrayVec;
10use hashbrown::hash_map::Entry;
11use shader_io_deductions::{display_deductions_as_optional_list, MaxVertexShaderOutputDeduction};
12use thiserror::Error;
13use wgt::{
14    error::{ErrorType, WebGpuError},
15    BindGroupLayoutEntry, BindingType,
16};
17
18use crate::{
19    device::bgl, resource::InvalidResourceError,
20    validation::shader_io_deductions::MaxFragmentShaderInputDeduction, FastHashMap, FastHashSet,
21};
22
23pub mod shader_io_deductions;
24
25#[derive(Debug)]
26enum ResourceType {
27    Buffer {
28        size: wgt::BufferSize,
29    },
30    Texture {
31        dim: naga::ImageDimension,
32        arrayed: bool,
33        class: naga::ImageClass,
34    },
35    Sampler {
36        comparison: bool,
37    },
38    AccelerationStructure {
39        vertex_return: bool,
40    },
41}
42
43#[derive(Clone, Debug)]
44pub enum BindingTypeName {
45    Buffer,
46    Texture,
47    Sampler,
48    AccelerationStructure,
49    ExternalTexture,
50}
51
52impl From<&ResourceType> for BindingTypeName {
53    fn from(ty: &ResourceType) -> BindingTypeName {
54        match ty {
55            ResourceType::Buffer { .. } => BindingTypeName::Buffer,
56            ResourceType::Texture {
57                class: naga::ImageClass::External,
58                ..
59            } => BindingTypeName::ExternalTexture,
60            ResourceType::Texture { .. } => BindingTypeName::Texture,
61            ResourceType::Sampler { .. } => BindingTypeName::Sampler,
62            ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
63        }
64    }
65}
66
67impl From<&BindingType> for BindingTypeName {
68    fn from(ty: &BindingType) -> BindingTypeName {
69        match ty {
70            BindingType::Buffer { .. } => BindingTypeName::Buffer,
71            BindingType::Texture { .. } => BindingTypeName::Texture,
72            BindingType::StorageTexture { .. } => BindingTypeName::Texture,
73            BindingType::Sampler { .. } => BindingTypeName::Sampler,
74            BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
75            BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
76        }
77    }
78}
79
80#[derive(Debug)]
81struct Resource {
82    #[allow(unused)]
83    name: Option<String>,
84    bind: naga::ResourceBinding,
85    ty: ResourceType,
86    class: naga::AddressSpace,
87}
88
89#[derive(Clone, Copy, Debug, Eq, PartialEq)]
90enum NumericDimension {
91    Scalar,
92    Vector(naga::VectorSize),
93    Matrix(naga::VectorSize, naga::VectorSize),
94}
95
96impl fmt::Display for NumericDimension {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        match *self {
99            Self::Scalar => write!(f, ""),
100            Self::Vector(size) => write!(f, "x{}", size as u8),
101            Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
102        }
103    }
104}
105
106#[derive(Clone, Copy, Debug, Eq, PartialEq)]
107pub struct NumericType {
108    dim: NumericDimension,
109    scalar: naga::Scalar,
110}
111
112impl fmt::Display for NumericType {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(
115            f,
116            "{:?}{}{}",
117            self.scalar.kind,
118            self.scalar.width * 8,
119            self.dim
120        )
121    }
122}
123
124#[derive(Clone, Debug, Eq, PartialEq)]
125pub struct InterfaceVar {
126    pub ty: NumericType,
127    interpolation: Option<naga::Interpolation>,
128    sampling: Option<naga::Sampling>,
129    per_primitive: bool,
130}
131
132impl InterfaceVar {
133    pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
134        InterfaceVar {
135            ty: NumericType::from_vertex_format(format),
136            interpolation: None,
137            sampling: None,
138            per_primitive: false,
139        }
140    }
141}
142
143impl fmt::Display for InterfaceVar {
144    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145        write!(
146            f,
147            "{} interpolated as {:?} with sampling {:?}",
148            self.ty, self.interpolation, self.sampling
149        )
150    }
151}
152
153#[derive(Debug, Eq, PartialEq)]
154enum Varying {
155    Local { location: u32, iv: InterfaceVar },
156    BuiltIn(naga::BuiltIn),
157}
158
159#[allow(unused)]
160#[derive(Debug)]
161struct SpecializationConstant {
162    id: u32,
163    ty: NumericType,
164}
165
166#[derive(Debug)]
167struct EntryPointMeshInfo {
168    max_vertices: u32,
169    max_primitives: u32,
170}
171
172#[derive(Debug, Default)]
173struct EntryPoint {
174    inputs: Vec<Varying>,
175    outputs: Vec<Varying>,
176    resources: Vec<naga::Handle<Resource>>,
177    #[allow(unused)]
178    spec_constants: Vec<SpecializationConstant>,
179    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
180    workgroup_size: [u32; 3],
181    dual_source_blending: bool,
182    task_payload_size: Option<u32>,
183    mesh_info: Option<EntryPointMeshInfo>,
184}
185
186#[derive(Debug)]
187pub struct Interface {
188    limits: wgt::Limits,
189    resources: naga::Arena<Resource>,
190    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
191}
192
193#[derive(Clone, Debug, Error)]
194#[non_exhaustive]
195pub enum BindingError {
196    #[error("Binding is missing from the pipeline layout")]
197    Missing,
198    #[error("Visibility flags don't include the shader stage")]
199    Invisible,
200    #[error(
201        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
202    )]
203    WrongType {
204        binding: BindingTypeName,
205        shader: BindingTypeName,
206    },
207    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
208    WrongAddressSpace {
209        binding: naga::AddressSpace,
210        shader: naga::AddressSpace,
211    },
212    #[error("Address space {space:?} is not a valid Buffer address space")]
213    WrongBufferAddressSpace { space: naga::AddressSpace },
214    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
215    WrongBufferSize {
216        buffer_size: wgt::BufferSize,
217        min_binding_size: wgt::BufferSize,
218    },
219    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
220    WrongTextureViewDimension {
221        dim: naga::ImageDimension,
222        is_array: bool,
223        binding: BindingType,
224    },
225    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
226    WrongTextureClass {
227        binding: naga::ImageClass,
228        shader: naga::ImageClass,
229    },
230    #[error("Comparison flag doesn't match the shader")]
231    WrongSamplerComparison,
232    #[error("Derived bind group layout type is not consistent between stages")]
233    InconsistentlyDerivedType,
234    #[error("Texture format {0:?} is not supported for storage use")]
235    BadStorageFormat(wgt::TextureFormat),
236}
237
238impl WebGpuError for BindingError {
239    fn webgpu_error_type(&self) -> ErrorType {
240        ErrorType::Validation
241    }
242}
243
244#[derive(Clone, Debug, Error)]
245#[non_exhaustive]
246pub enum FilteringError {
247    #[error("Integer textures can't be sampled with a filtering sampler")]
248    Integer,
249    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
250    Float,
251}
252
253impl WebGpuError for FilteringError {
254    fn webgpu_error_type(&self) -> ErrorType {
255        ErrorType::Validation
256    }
257}
258
259#[derive(Clone, Debug, Error)]
260#[non_exhaustive]
261pub enum InputError {
262    #[error("Input is not provided by the earlier stage in the pipeline")]
263    Missing,
264    #[error("Input type is not compatible with the provided {0}")]
265    WrongType(NumericType),
266    #[error("Input interpolation doesn't match provided {0:?}")]
267    InterpolationMismatch(Option<naga::Interpolation>),
268    #[error("Input sampling doesn't match provided {0:?}")]
269    SamplingMismatch(Option<naga::Sampling>),
270    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
271    WrongPerPrimitive { pipeline_input: bool, shader: bool },
272}
273
274impl WebGpuError for InputError {
275    fn webgpu_error_type(&self) -> ErrorType {
276        ErrorType::Validation
277    }
278}
279
280/// Errors produced when validating a programmable stage of a pipeline.
281#[derive(Clone, Debug, Error)]
282#[non_exhaustive]
283pub enum StageError {
284    #[error(
285        "Shader entry point's workgroup size {dimensions:?} ({total} total invocations) must be \
286        less or equal to the per-dimension limit `Limits::{per_dimension_limits_desc}` of \
287        {per_dimension_limits:?} and the total invocation limit `Limits::{total_limit_desc}` of \
288        {total_limit}"
289    )]
290    InvalidWorkgroupSize {
291        dimensions: [u32; 3],
292        per_dimension_limits: [u32; 3],
293        per_dimension_limits_desc: &'static str,
294        total: u32,
295        total_limit: u32,
296        total_limit_desc: &'static str,
297    },
298    #[error("Unable to find entry point '{0}'")]
299    MissingEntryPoint(String),
300    #[error("Shader global {0:?} is not available in the pipeline layout")]
301    Binding(naga::ResourceBinding, #[source] BindingError),
302    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
303    Filtering {
304        texture: naga::ResourceBinding,
305        sampler: naga::ResourceBinding,
306        #[source]
307        error: FilteringError,
308    },
309    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
310    Input {
311        location: wgt::ShaderLocation,
312        var: InterfaceVar,
313        #[source]
314        error: InputError,
315    },
316    #[error(
317        "Unable to select an entry point: no entry point was found in the provided shader module"
318    )]
319    NoEntryPointFound,
320    #[error(
321        "Unable to select an entry point: \
322        multiple entry points were found in the provided shader module, \
323        but no entry point was specified"
324    )]
325    MultipleEntryPointsFound,
326    #[error(transparent)]
327    InvalidResource(#[from] InvalidResourceError),
328    #[error(
329        "vertex shader output location Location[{location}] ({var}) exceeds the \
330        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
331        // NOTE: Remember: the limit is 0-based for indices.
332        limit - 1,
333        display_deductions_as_optional_list(deductions, |d| d.for_location())
334    )]
335    VertexOutputLocationTooLarge {
336        location: u32,
337        var: InterfaceVar,
338        limit: u32,
339        deductions: Vec<MaxVertexShaderOutputDeduction>,
340    },
341    #[error(
342        "found {num_found} user-defined vertex shader output variables, which exceeds the \
343        `max_inter_stage_shader_variables` limit ({limit}){}",
344        display_deductions_as_optional_list(deductions, |d| d.for_variables())
345    )]
346    TooManyUserDefinedVertexOutputs {
347        num_found: u32,
348        limit: u32,
349        deductions: Vec<MaxVertexShaderOutputDeduction>,
350    },
351    #[error(
352        "fragment shader input location Location[{location}] ({var}) exceeds the \
353        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
354        // NOTE: Remember: the limit is 0-based for indices.
355        limit - 1,
356        // NOTE: WebGPU spec. validation for fragment inputs is expressed in terms of variables
357        // (unlike vertex outputs), so we use `MaxFragmentShaderInputDeduction::for_variables` here
358        // (and not a non-existent `for_locations`).
359        display_deductions_as_optional_list(deductions, |d| d.for_variables())
360    )]
361    FragmentInputLocationTooLarge {
362        location: u32,
363        var: InterfaceVar,
364        limit: u32,
365        deductions: Vec<MaxFragmentShaderInputDeduction>,
366    },
367    #[error(
368        "found {num_found} user-defined fragment shader input variables, which exceeds the \
369        `max_inter_stage_shader_variables` limit ({limit}){}",
370        display_deductions_as_optional_list(deductions, |d| d.for_variables())
371    )]
372    TooManyUserDefinedFragmentInputs {
373        num_found: u32,
374        limit: u32,
375        deductions: Vec<MaxFragmentShaderInputDeduction>,
376    },
377    #[error(
378        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
379    )]
380    ColorAttachmentLocationTooLarge {
381        location: u32,
382        var: InterfaceVar,
383        limit: u32,
384    },
385    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
386    TooManyMeshVertices { limit: u32, value: u32 },
387    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
388    TooManyMeshPrimitives { limit: u32, value: u32 },
389    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
390    TaskPayloadTooLarge { limit: u32, value: u32 },
391    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
392    TaskPayloadMustMatch {
393        input: Option<u32>,
394        shader: Option<u32>,
395    },
396    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
397    InvalidPrimitiveIndex,
398    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
399    MissingPrimitiveIndex,
400    #[error("DrawId cannot be used in the same pipeline as a task shader")]
401    DrawIdError,
402    #[error("Pipeline uses dual-source blending, but the shader does not support it")]
403    InvalidDualSourceBlending,
404    #[error("Fragment shader writes depth, but pipeline does not have a depth attachment")]
405    MissingFragDepthAttachment,
406}
407
408impl WebGpuError for StageError {
409    fn webgpu_error_type(&self) -> ErrorType {
410        match self {
411            Self::Binding(_, e) => e.webgpu_error_type(),
412            Self::InvalidResource(e) => e.webgpu_error_type(),
413            Self::Filtering {
414                texture: _,
415                sampler: _,
416                error,
417            } => error.webgpu_error_type(),
418            Self::Input {
419                location: _,
420                var: _,
421                error,
422            } => error.webgpu_error_type(),
423            Self::InvalidWorkgroupSize { .. }
424            | Self::MissingEntryPoint(..)
425            | Self::NoEntryPointFound
426            | Self::MultipleEntryPointsFound
427            | Self::VertexOutputLocationTooLarge { .. }
428            | Self::TooManyUserDefinedVertexOutputs { .. }
429            | Self::FragmentInputLocationTooLarge { .. }
430            | Self::TooManyUserDefinedFragmentInputs { .. }
431            | Self::ColorAttachmentLocationTooLarge { .. }
432            | Self::TooManyMeshVertices { .. }
433            | Self::TooManyMeshPrimitives { .. }
434            | Self::TaskPayloadTooLarge { .. }
435            | Self::TaskPayloadMustMatch { .. }
436            | Self::InvalidPrimitiveIndex
437            | Self::MissingPrimitiveIndex
438            | Self::DrawIdError
439            | Self::InvalidDualSourceBlending
440            | Self::MissingFragDepthAttachment => ErrorType::Validation,
441        }
442    }
443}
444
445pub use wgpu_naga_bridge::map_storage_format_from_naga;
446pub use wgpu_naga_bridge::map_storage_format_to_naga;
447
448impl Resource {
449    fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
450        match self.ty {
451            ResourceType::Buffer { size } => {
452                let min_size = match entry.ty {
453                    BindingType::Buffer {
454                        ty,
455                        has_dynamic_offset: _,
456                        min_binding_size,
457                    } => {
458                        let class = match ty {
459                            wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
460                            wgt::BufferBindingType::Storage { read_only } => {
461                                let mut naga_access = naga::StorageAccess::LOAD;
462                                naga_access.set(naga::StorageAccess::STORE, !read_only);
463                                naga::AddressSpace::Storage {
464                                    access: naga_access,
465                                }
466                            }
467                        };
468                        if self.class != class {
469                            return Err(BindingError::WrongAddressSpace {
470                                binding: class,
471                                shader: self.class,
472                            });
473                        }
474                        min_binding_size
475                    }
476                    _ => {
477                        return Err(BindingError::WrongType {
478                            binding: (&entry.ty).into(),
479                            shader: (&self.ty).into(),
480                        })
481                    }
482                };
483                match min_size {
484                    Some(non_zero) if non_zero < size => {
485                        return Err(BindingError::WrongBufferSize {
486                            buffer_size: size,
487                            min_binding_size: non_zero,
488                        })
489                    }
490                    _ => (),
491                }
492            }
493            ResourceType::Sampler { comparison } => match entry.ty {
494                BindingType::Sampler(ty) => {
495                    if (ty == wgt::SamplerBindingType::Comparison) != comparison {
496                        return Err(BindingError::WrongSamplerComparison);
497                    }
498                }
499                _ => {
500                    return Err(BindingError::WrongType {
501                        binding: (&entry.ty).into(),
502                        shader: (&self.ty).into(),
503                    })
504                }
505            },
506            ResourceType::Texture {
507                dim,
508                arrayed,
509                class: shader_class,
510            } => {
511                let view_dimension = match entry.ty {
512                    BindingType::Texture { view_dimension, .. }
513                    | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
514                    BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
515                    _ => {
516                        return Err(BindingError::WrongTextureViewDimension {
517                            dim,
518                            is_array: false,
519                            binding: entry.ty,
520                        })
521                    }
522                };
523                if arrayed {
524                    match (dim, view_dimension) {
525                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
526                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
527                        _ => {
528                            return Err(BindingError::WrongTextureViewDimension {
529                                dim,
530                                is_array: true,
531                                binding: entry.ty,
532                            })
533                        }
534                    }
535                } else {
536                    match (dim, view_dimension) {
537                        (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
538                        (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
539                        (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
540                        (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
541                        _ => {
542                            return Err(BindingError::WrongTextureViewDimension {
543                                dim,
544                                is_array: false,
545                                binding: entry.ty,
546                            })
547                        }
548                    }
549                }
550                match entry.ty {
551                    BindingType::Texture {
552                        sample_type,
553                        view_dimension: _,
554                        multisampled: multi,
555                    } => {
556                        let binding_class = match sample_type {
557                            wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
558                                kind: naga::ScalarKind::Float,
559                                multi,
560                            },
561                            wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
562                                kind: naga::ScalarKind::Sint,
563                                multi,
564                            },
565                            wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
566                                kind: naga::ScalarKind::Uint,
567                                multi,
568                            },
569                            wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
570                        };
571                        if shader_class == binding_class {
572                            Ok(())
573                        } else {
574                            Err(binding_class)
575                        }
576                    }
577                    BindingType::StorageTexture {
578                        access: wgt_binding_access,
579                        format: wgt_binding_format,
580                        view_dimension: _,
581                    } => {
582                        const LOAD_STORE: naga::StorageAccess =
583                            naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
584                        let binding_format = map_storage_format_to_naga(wgt_binding_format)
585                            .ok_or(BindingError::BadStorageFormat(wgt_binding_format))?;
586                        let binding_access = match wgt_binding_access {
587                            wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
588                            wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
589                            wgt::StorageTextureAccess::ReadWrite => LOAD_STORE,
590                            wgt::StorageTextureAccess::Atomic => {
591                                naga::StorageAccess::ATOMIC | LOAD_STORE
592                            }
593                        };
594                        match shader_class {
595                            // Formats must match exactly. A write-only shader (but not a
596                            // read-only shader) is compatible with a read-write binding.
597                            naga::ImageClass::Storage {
598                                format: shader_format,
599                                access: shader_access,
600                            } if shader_format == binding_format
601                                && (shader_access == binding_access
602                                    || shader_access == naga::StorageAccess::STORE
603                                        && binding_access == LOAD_STORE) =>
604                            {
605                                Ok(())
606                            }
607                            _ => Err(naga::ImageClass::Storage {
608                                format: binding_format,
609                                access: binding_access,
610                            }),
611                        }
612                    }
613                    BindingType::ExternalTexture => {
614                        let binding_class = naga::ImageClass::External;
615                        if shader_class == binding_class {
616                            Ok(())
617                        } else {
618                            Err(binding_class)
619                        }
620                    }
621                    _ => {
622                        return Err(BindingError::WrongType {
623                            binding: (&entry.ty).into(),
624                            shader: (&self.ty).into(),
625                        })
626                    }
627                }
628                .map_err(|binding_class| BindingError::WrongTextureClass {
629                    binding: binding_class,
630                    shader: shader_class,
631                })?;
632            }
633            ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
634                BindingType::AccelerationStructure {
635                    vertex_return: entry_vertex_return,
636                } if vertex_return == entry_vertex_return => (),
637                _ => {
638                    return Err(BindingError::WrongType {
639                        binding: (&entry.ty).into(),
640                        shader: (&self.ty).into(),
641                    })
642                }
643            },
644        };
645
646        Ok(())
647    }
648
649    fn derive_binding_type(
650        &self,
651        is_reffed_by_sampler_in_entrypoint: bool,
652    ) -> Result<BindingType, BindingError> {
653        Ok(match self.ty {
654            ResourceType::Buffer { size } => BindingType::Buffer {
655                ty: match self.class {
656                    naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
657                    naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
658                        read_only: access == naga::StorageAccess::LOAD,
659                    },
660                    _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
661                },
662                has_dynamic_offset: false,
663                min_binding_size: Some(size),
664            },
665            ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
666                wgt::SamplerBindingType::Comparison
667            } else {
668                wgt::SamplerBindingType::Filtering
669            }),
670            ResourceType::Texture {
671                dim,
672                arrayed,
673                class,
674            } => {
675                let view_dimension = match dim {
676                    naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
677                    naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
678                    naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
679                    naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
680                    naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
681                    naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
682                };
683                match class {
684                    naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
685                        sample_type: match kind {
686                            naga::ScalarKind::Float => wgt::TextureSampleType::Float {
687                                filterable: is_reffed_by_sampler_in_entrypoint,
688                            },
689                            naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
690                            naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
691                            naga::ScalarKind::AbstractInt
692                            | naga::ScalarKind::AbstractFloat
693                            | naga::ScalarKind::Bool => unreachable!(),
694                        },
695                        view_dimension,
696                        multisampled: multi,
697                    },
698                    naga::ImageClass::Depth { multi } => BindingType::Texture {
699                        sample_type: wgt::TextureSampleType::Depth,
700                        view_dimension,
701                        multisampled: multi,
702                    },
703                    naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
704                        access: {
705                            const LOAD_STORE: naga::StorageAccess =
706                                naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
707                            match access {
708                                naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
709                                naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
710                                LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
711                                _ if access.contains(naga::StorageAccess::ATOMIC) => {
712                                    wgt::StorageTextureAccess::Atomic
713                                }
714                                _ => unreachable!(),
715                            }
716                        },
717                        view_dimension,
718                        format: {
719                            let f = map_storage_format_from_naga(format);
720                            let original = map_storage_format_to_naga(f)
721                                .ok_or(BindingError::BadStorageFormat(f))?;
722                            debug_assert_eq!(format, original);
723                            f
724                        },
725                    },
726                    naga::ImageClass::External => BindingType::ExternalTexture,
727                }
728            }
729            ResourceType::AccelerationStructure { vertex_return } => {
730                BindingType::AccelerationStructure { vertex_return }
731            }
732        })
733    }
734}
735
736impl NumericType {
737    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
738        use naga::{Scalar, VectorSize as Vs};
739        use wgt::VertexFormat as Vf;
740
741        let (dim, scalar) = match format {
742            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
743            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
744                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
745            }
746            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
747            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
748                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
749            }
750            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
751            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
752                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
753            }
754            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
755            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
756                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
757            }
758            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
759                (NumericDimension::Scalar, Scalar::F32)
760            }
761            Vf::Unorm8x2
762            | Vf::Snorm8x2
763            | Vf::Unorm16x2
764            | Vf::Snorm16x2
765            | Vf::Float16x2
766            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
767            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
768            Vf::Unorm8x4
769            | Vf::Snorm8x4
770            | Vf::Unorm16x4
771            | Vf::Snorm16x4
772            | Vf::Float16x4
773            | Vf::Float32x4
774            | Vf::Unorm10_10_10_2
775            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
776            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
777            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
778            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
779            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
780        };
781
782        NumericType {
783            dim,
784            //Note: Shader always sees data as int, uint, or float.
785            // It doesn't know if the original is normalized in a tighter form.
786            scalar,
787        }
788    }
789
790    fn from_texture_format(format: wgt::TextureFormat) -> Self {
791        use naga::{Scalar, VectorSize as Vs};
792        use wgt::TextureFormat as Tf;
793
794        let (dim, scalar) = match format {
795            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
796                (NumericDimension::Scalar, Scalar::F32)
797            }
798            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
799            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
800            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
801                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
802            }
803            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
804            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
805                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
806            }
807            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
808                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
809            }
810            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
811            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
812            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
813            Tf::Rgba8Unorm
814            | Tf::Rgba8UnormSrgb
815            | Tf::Rgba8Snorm
816            | Tf::Bgra8Unorm
817            | Tf::Bgra8UnormSrgb
818            | Tf::Rgb10a2Unorm
819            | Tf::Rgba16Float
820            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
821            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
822                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
823            }
824            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
825                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
826            }
827            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
828            Tf::Stencil8
829            | Tf::Depth16Unorm
830            | Tf::Depth32Float
831            | Tf::Depth32FloatStencil8
832            | Tf::Depth24Plus
833            | Tf::Depth24PlusStencil8 => {
834                panic!("Unexpected depth format")
835            }
836            Tf::NV12 => panic!("Unexpected nv12 format"),
837            Tf::P010 => panic!("Unexpected p010 format"),
838            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
839            Tf::Bc1RgbaUnorm
840            | Tf::Bc1RgbaUnormSrgb
841            | Tf::Bc2RgbaUnorm
842            | Tf::Bc2RgbaUnormSrgb
843            | Tf::Bc3RgbaUnorm
844            | Tf::Bc3RgbaUnormSrgb
845            | Tf::Bc7RgbaUnorm
846            | Tf::Bc7RgbaUnormSrgb
847            | Tf::Etc2Rgb8A1Unorm
848            | Tf::Etc2Rgb8A1UnormSrgb
849            | Tf::Etc2Rgba8Unorm
850            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
851            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
852                (NumericDimension::Scalar, Scalar::F32)
853            }
854            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
855                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
856            }
857            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
858                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
859            }
860            Tf::Astc {
861                block: _,
862                channel: _,
863            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
864        };
865
866        NumericType {
867            dim,
868            //Note: Shader always sees data as int, uint, or float.
869            // It doesn't know if the original is normalized in a tighter form.
870            scalar,
871        }
872    }
873
874    fn is_subtype_of(&self, other: &NumericType) -> bool {
875        if self.scalar.width > other.scalar.width {
876            return false;
877        }
878        if self.scalar.kind != other.scalar.kind {
879            return false;
880        }
881        match (self.dim, other.dim) {
882            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
883            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
884            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
885            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
886                c0 == c1 && r0 == r1
887            }
888            _ => false,
889        }
890    }
891}
892
893/// Return true if the fragment `format` is covered by the provided `output`.
894pub fn check_texture_format(
895    format: wgt::TextureFormat,
896    output: &NumericType,
897) -> Result<(), NumericType> {
898    let nt = NumericType::from_texture_format(format);
899    if nt.is_subtype_of(output) {
900        Ok(())
901    } else {
902        Err(nt)
903    }
904}
905
/// Where the bind group layouts used while validating a shader stage come from.
pub enum BindingLayoutSource {
    /// The binding layout is derived from the shader's own resource bindings rather than
    /// supplied by the user.
    ///
    /// Holds one [`bgl::EntryMap`] per bind group, indexed by group number. This will be
    /// filled in by the shader binding validation, as it iterates the shader's interfaces.
    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
    /// The binding layout is provided by the user in BGLs.
    ///
    /// This will be validated against the shader's interfaces.
    Provided(Arc<crate::binding_model::PipelineLayout>),
}
916
917impl BindingLayoutSource {
918    pub fn new_derived(limits: &wgt::Limits) -> Self {
919        let mut array = ArrayVec::new();
920        for _ in 0..limits.max_bind_groups {
921            array.push(Default::default());
922        }
923        BindingLayoutSource::Derived(Box::new(array))
924    }
925}
926
/// Inter-stage I/O information carried from one shader stage to the next during validation.
///
/// `Interface::check_stage` takes one of these as its `inputs` (what the previous stage
/// provides) and returns one. Presumably the returned value describes this stage's outputs
/// for the following stage; confirm against the rest of `check_stage`.
#[derive(Debug, Clone, Default)]
pub struct StageIo {
    /// User-defined varyings available to the consuming stage, keyed by shader location.
    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
    /// This must match between mesh & task shaders
    pub task_payload_size: Option<u32>,
    /// Fragment shaders cannot input primitive index on mesh shaders that don't output it on DX12.
    /// Therefore, we track between shader stages if primitive index is written (or if vertex shader
    /// is used).
    ///
    /// This is Some if it was a mesh shader.
    pub primitive_index: Option<bool>,
}
939
940impl Interface {
941    fn populate(
942        list: &mut Vec<Varying>,
943        binding: Option<&naga::Binding>,
944        ty: naga::Handle<naga::Type>,
945        arena: &naga::UniqueArena<naga::Type>,
946    ) {
947        let numeric_ty = match arena[ty].inner {
948            naga::TypeInner::Scalar(scalar) => NumericType {
949                dim: NumericDimension::Scalar,
950                scalar,
951            },
952            naga::TypeInner::Vector { size, scalar } => NumericType {
953                dim: NumericDimension::Vector(size),
954                scalar,
955            },
956            naga::TypeInner::Matrix {
957                columns,
958                rows,
959                scalar,
960            } => NumericType {
961                dim: NumericDimension::Matrix(columns, rows),
962                scalar,
963            },
964            naga::TypeInner::Struct { ref members, .. } => {
965                for member in members {
966                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
967                }
968                return;
969            }
970            ref other => {
971                //Note: technically this should be at least `log::error`, but
972                // the reality is - every shader coming from `glslc` outputs an array
973                // of clip distances and hits this path :(
974                // So we lower it to `log::debug` to be less annoying as
975                // there's nothing the user can do about it.
976                log::debug!("Unexpected varying type: {other:?}");
977                return;
978            }
979        };
980
981        let varying = match binding {
982            Some(&naga::Binding::Location {
983                location,
984                interpolation,
985                sampling,
986                per_primitive,
987                blend_src: _,
988            }) => Varying::Local {
989                location,
990                iv: InterfaceVar {
991                    ty: numeric_ty,
992                    interpolation,
993                    sampling,
994                    per_primitive,
995                },
996            },
997            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(built_in),
998            None => {
999                log::error!("Missing binding for a varying");
1000                return;
1001            }
1002        };
1003        list.push(varying);
1004    }
1005
1006    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
1007        let mut resources = naga::Arena::new();
1008        let mut resource_mapping = FastHashMap::default();
1009        for (var_handle, var) in module.global_variables.iter() {
1010            let bind = match var.binding {
1011                Some(br) => br,
1012                _ => continue,
1013            };
1014            let naga_ty = &module.types[var.ty].inner;
1015
1016            let inner_ty = match *naga_ty {
1017                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
1018                ref ty => ty,
1019            };
1020
1021            let ty = match *inner_ty {
1022                naga::TypeInner::Image {
1023                    dim,
1024                    arrayed,
1025                    class,
1026                } => ResourceType::Texture {
1027                    dim,
1028                    arrayed,
1029                    class,
1030                },
1031                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
1032                naga::TypeInner::AccelerationStructure { vertex_return } => {
1033                    ResourceType::AccelerationStructure { vertex_return }
1034                }
1035                ref other => ResourceType::Buffer {
1036                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
1037                },
1038            };
1039            let handle = resources.append(
1040                Resource {
1041                    name: var.name.clone(),
1042                    bind,
1043                    ty,
1044                    class: var.space,
1045                },
1046                Default::default(),
1047            );
1048            resource_mapping.insert(var_handle, handle);
1049        }
1050
1051        let mut entry_points = FastHashMap::default();
1052        entry_points.reserve(module.entry_points.len());
1053        for (index, entry_point) in module.entry_points.iter().enumerate() {
1054            let info = info.get_entry_point(index);
1055            let mut ep = EntryPoint::default();
1056            for arg in entry_point.function.arguments.iter() {
1057                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
1058            }
1059            if let Some(ref result) = entry_point.function.result {
1060                Self::populate(
1061                    &mut ep.outputs,
1062                    result.binding.as_ref(),
1063                    result.ty,
1064                    &module.types,
1065                );
1066            }
1067
1068            for (var_handle, var) in module.global_variables.iter() {
1069                let usage = info[var_handle];
1070                if !usage.is_empty() && var.binding.is_some() {
1071                    ep.resources.push(resource_mapping[&var_handle]);
1072                }
1073            }
1074
1075            for key in info.sampling_set.iter() {
1076                ep.sampling_pairs
1077                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
1078            }
1079            ep.dual_source_blending = info.dual_source_blending;
1080            ep.workgroup_size = entry_point.workgroup_size;
1081
1082            if let Some(task_payload) = entry_point.task_payload {
1083                ep.task_payload_size = Some(
1084                    module.types[module.global_variables[task_payload].ty]
1085                        .inner
1086                        .size(module.to_ctx()),
1087                );
1088            }
1089            if let Some(ref mesh_info) = entry_point.mesh_info {
1090                ep.mesh_info = Some(EntryPointMeshInfo {
1091                    max_vertices: mesh_info.max_vertices,
1092                    max_primitives: mesh_info.max_primitives,
1093                });
1094                Self::populate(
1095                    &mut ep.outputs,
1096                    None,
1097                    mesh_info.vertex_output_type,
1098                    &module.types,
1099                );
1100                Self::populate(
1101                    &mut ep.outputs,
1102                    None,
1103                    mesh_info.primitive_output_type,
1104                    &module.types,
1105                );
1106            }
1107
1108            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
1109        }
1110
1111        Self {
1112            limits,
1113            resources,
1114            entry_points,
1115        }
1116    }
1117
1118    pub fn finalize_entry_point_name(
1119        &self,
1120        stage: naga::ShaderStage,
1121        entry_point_name: Option<&str>,
1122    ) -> Result<String, StageError> {
1123        entry_point_name
1124            .map(|ep| ep.to_string())
1125            .map(Ok)
1126            .unwrap_or_else(|| {
1127                let mut entry_points = self
1128                    .entry_points
1129                    .keys()
1130                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
1131                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1132                if entry_points.next().is_some() {
1133                    return Err(StageError::MultipleEntryPointsFound);
1134                }
1135                Ok(first.clone())
1136            })
1137    }
1138
1139    /// Among other things, this implements some validation logic defined by the WebGPU spec. at
1140    /// <https://www.w3.org/TR/webgpu/#abstract-opdef-validating-inter-stage-interfaces>.
1141    pub fn check_stage(
1142        &self,
1143        layouts: &mut BindingLayoutSource,
1144        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
1145        entry_point_name: &str,
1146        shader_stage: ShaderStageForValidation,
1147        inputs: StageIo,
1148    ) -> Result<StageIo, StageError> {
1149        // Since a shader module can have multiple entry points with the same name,
1150        // we need to look for one with the right execution model.
1151        let pair = (shader_stage.to_naga(), entry_point_name.to_string());
1152        let entry_point = match self.entry_points.get(&pair) {
1153            Some(some) => some,
1154            None => return Err(StageError::MissingEntryPoint(pair.1)),
1155        };
1156        let (_, entry_point_name) = pair;
1157
1158        let stage_bit = shader_stage.to_wgt_bit();
1159
1160        // check resources visibility
1161        for &handle in entry_point.resources.iter() {
1162            let res = &self.resources[handle];
1163            let result = 'err: {
1164                match layouts {
1165                    BindingLayoutSource::Provided(pipeline_layout) => {
1166                        // update the required binding size for this buffer
1167                        if let ResourceType::Buffer { size } = res.ty {
1168                            match shader_binding_sizes.entry(res.bind) {
1169                                Entry::Occupied(e) => {
1170                                    *e.into_mut() = size.max(*e.get());
1171                                }
1172                                Entry::Vacant(e) => {
1173                                    e.insert(size);
1174                                }
1175                            }
1176                        }
1177
1178                        let Some(entry) =
1179                            pipeline_layout.get_bgl_entry(res.bind.group, res.bind.binding)
1180                        else {
1181                            break 'err Err(BindingError::Missing);
1182                        };
1183
1184                        if !entry.visibility.contains(stage_bit) {
1185                            break 'err Err(BindingError::Invisible);
1186                        }
1187
1188                        res.check_binding_use(entry)
1189                    }
1190                    BindingLayoutSource::Derived(layouts) => {
1191                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
1192                            break 'err Err(BindingError::Missing);
1193                        };
1194
1195                        let ty = match res.derive_binding_type(
1196                            entry_point
1197                                .sampling_pairs
1198                                .iter()
1199                                .any(|&(im, _samp)| im == handle),
1200                        ) {
1201                            Ok(ty) => ty,
1202                            Err(error) => break 'err Err(error),
1203                        };
1204
1205                        match map.entry(res.bind.binding) {
1206                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
1207                                break 'err Err(BindingError::InconsistentlyDerivedType)
1208                            }
1209                            indexmap::map::Entry::Occupied(e) => {
1210                                e.into_mut().visibility |= stage_bit;
1211                            }
1212                            indexmap::map::Entry::Vacant(e) => {
1213                                e.insert(BindGroupLayoutEntry {
1214                                    binding: res.bind.binding,
1215                                    ty,
1216                                    visibility: stage_bit,
1217                                    count: None,
1218                                });
1219                            }
1220                        }
1221                        Ok(())
1222                    }
1223                }
1224            };
1225            if let Err(error) = result {
1226                return Err(StageError::Binding(res.bind, error));
1227            }
1228        }
1229
1230        // Check the compatibility between textures and samplers
1231        //
1232        // We only need to do this if the binding layout is provided by the user, as derived
1233        // layouts will inherently be correctly tagged.
1234        if let BindingLayoutSource::Provided(pipeline_layout) = layouts {
1235            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
1236                let texture_bind = &self.resources[texture_handle].bind;
1237                let sampler_bind = &self.resources[sampler_handle].bind;
1238                let texture_layout = pipeline_layout
1239                    .get_bgl_entry(texture_bind.group, texture_bind.binding)
1240                    .unwrap();
1241                let sampler_layout = pipeline_layout
1242                    .get_bgl_entry(sampler_bind.group, sampler_bind.binding)
1243                    .unwrap();
1244                assert!(texture_layout.visibility.contains(stage_bit));
1245                assert!(sampler_layout.visibility.contains(stage_bit));
1246
1247                let sampler_filtering = matches!(
1248                    sampler_layout.ty,
1249                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
1250                );
1251                let texture_sample_type = match texture_layout.ty {
1252                    BindingType::Texture { sample_type, .. } => sample_type,
1253                    BindingType::ExternalTexture => {
1254                        wgt::TextureSampleType::Float { filterable: true }
1255                    }
1256                    _ => unreachable!(),
1257                };
1258
1259                let error = match (sampler_filtering, texture_sample_type) {
1260                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
1261                        Some(FilteringError::Float)
1262                    }
1263                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
1264                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
1265                    _ => None,
1266                };
1267
1268                if let Some(error) = error {
1269                    return Err(StageError::Filtering {
1270                        texture: *texture_bind,
1271                        sampler: *sampler_bind,
1272                        error,
1273                    });
1274                }
1275            }
1276        }
1277
1278        // check workgroup size limits
1279        if shader_stage.to_naga().compute_like() {
1280            let (
1281                max_workgroup_size_limits,
1282                max_workgroup_size_total,
1283                per_dimension_limit,
1284                total_limit,
1285            ) = match shader_stage.to_naga() {
1286                naga::ShaderStage::Compute => (
1287                    [
1288                        self.limits.max_compute_workgroup_size_x,
1289                        self.limits.max_compute_workgroup_size_y,
1290                        self.limits.max_compute_workgroup_size_z,
1291                    ],
1292                    self.limits.max_compute_invocations_per_workgroup,
1293                    "max_compute_workgroup_size_*",
1294                    "max_compute_invocations_per_workgroup",
1295                ),
1296                naga::ShaderStage::Task => (
1297                    [
1298                        self.limits.max_task_invocations_per_dimension,
1299                        self.limits.max_task_invocations_per_dimension,
1300                        self.limits.max_task_invocations_per_dimension,
1301                    ],
1302                    self.limits.max_task_invocations_per_workgroup,
1303                    "max_task_invocations_per_dimension",
1304                    "max_task_invocations_per_workgroup",
1305                ),
1306                naga::ShaderStage::Mesh => (
1307                    [
1308                        self.limits.max_mesh_invocations_per_dimension,
1309                        self.limits.max_mesh_invocations_per_dimension,
1310                        self.limits.max_mesh_invocations_per_dimension,
1311                    ],
1312                    self.limits.max_mesh_invocations_per_workgroup,
1313                    "max_mesh_invocations_per_dimension",
1314                    "max_mesh_invocations_per_workgroup",
1315                ),
1316                _ => unreachable!(),
1317            };
1318
1319            let total_invocations = entry_point
1320                .workgroup_size
1321                .iter()
1322                .fold(1u32, |total, &dim| total.saturating_mul(dim));
1323            let invalid_total_invocations =
1324                total_invocations > max_workgroup_size_total || total_invocations == 0;
1325
1326            assert_eq!(
1327                entry_point.workgroup_size.len(),
1328                max_workgroup_size_limits.len()
1329            );
1330            let dimension_too_large = entry_point
1331                .workgroup_size
1332                .iter()
1333                .zip(max_workgroup_size_limits.iter())
1334                .any(|(dim, limit)| dim > limit);
1335
1336            if invalid_total_invocations || dimension_too_large {
1337                return Err(StageError::InvalidWorkgroupSize {
1338                    dimensions: entry_point.workgroup_size,
1339                    total: total_invocations,
1340                    per_dimension_limits: max_workgroup_size_limits,
1341                    total_limit: max_workgroup_size_total,
1342                    per_dimension_limits_desc: per_dimension_limit,
1343                    total_limit_desc: total_limit,
1344                });
1345            }
1346        }
1347
1348        let mut this_stage_primitive_index = false;
1349        let mut has_draw_id = false;
1350
1351        // check inputs compatibility
1352        for input in entry_point.inputs.iter() {
1353            match *input {
1354                Varying::Local { location, ref iv } => {
1355                    let result = inputs
1356                        .varyings
1357                        .get(&location)
1358                        .ok_or(InputError::Missing)
1359                        .and_then(|provided| {
1360                            let (compatible, per_primitive_correct) = match shader_stage.to_naga() {
1361                                // For vertex attributes, there are defaults filled out
1362                                // by the driver if data is not provided.
1363                                naga::ShaderStage::Vertex => {
1364                                    let is_compatible =
1365                                        iv.ty.scalar.kind == provided.ty.scalar.kind;
1366                                    // vertex inputs don't count towards inter-stage
1367                                    (is_compatible, !iv.per_primitive)
1368                                }
1369                                naga::ShaderStage::Fragment => {
1370                                    if iv.interpolation != provided.interpolation {
1371                                        return Err(InputError::InterpolationMismatch(
1372                                            provided.interpolation,
1373                                        ));
1374                                    }
1375                                    if iv.sampling != provided.sampling {
1376                                        return Err(InputError::SamplingMismatch(
1377                                            provided.sampling,
1378                                        ));
1379                                    }
1380                                    (
1381                                        iv.ty.is_subtype_of(&provided.ty),
1382                                        iv.per_primitive == provided.per_primitive,
1383                                    )
1384                                }
1385                                // These can't have varying inputs
1386                                naga::ShaderStage::Compute
1387                                | naga::ShaderStage::Task
1388                                | naga::ShaderStage::Mesh => (false, false),
1389                                naga::ShaderStage::RayGeneration
1390                                | naga::ShaderStage::AnyHit
1391                                | naga::ShaderStage::ClosestHit
1392                                | naga::ShaderStage::Miss => {
1393                                    unreachable!()
1394                                }
1395                            };
1396                            if !compatible {
1397                                return Err(InputError::WrongType(provided.ty));
1398                            } else if !per_primitive_correct {
1399                                return Err(InputError::WrongPerPrimitive {
1400                                    pipeline_input: provided.per_primitive,
1401                                    shader: iv.per_primitive,
1402                                });
1403                            }
1404                            Ok(())
1405                        });
1406
1407                    if let Err(error) = result {
1408                        return Err(StageError::Input {
1409                            location,
1410                            var: iv.clone(),
1411                            error,
1412                        });
1413                    }
1414                }
1415                Varying::BuiltIn(naga::BuiltIn::PrimitiveIndex) => {
1416                    this_stage_primitive_index = true;
1417                }
1418                Varying::BuiltIn(naga::BuiltIn::DrawIndex) => {
1419                    has_draw_id = true;
1420                }
1421                Varying::BuiltIn(_) => {}
1422            }
1423        }
1424
1425        match shader_stage {
1426            ShaderStageForValidation::Vertex {
1427                topology,
1428                compare_function,
1429            } => {
1430                let mut max_vertex_shader_output_variables =
1431                    self.limits.max_inter_stage_shader_variables;
1432                let mut max_vertex_shader_output_location = max_vertex_shader_output_variables - 1;
1433
1434                let point_list_deduction = if topology == wgt::PrimitiveTopology::PointList {
1435                    Some(MaxVertexShaderOutputDeduction::PointListPrimitiveTopology)
1436                } else {
1437                    None
1438                };
1439
1440                let deductions = point_list_deduction.into_iter();
1441
1442                for deduction in deductions.clone() {
1443                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1444                    // ever exceed the minimum variables available.
1445                    max_vertex_shader_output_variables = max_vertex_shader_output_variables
1446                        .checked_sub(deduction.for_variables())
1447                        .unwrap();
1448                    max_vertex_shader_output_location = max_vertex_shader_output_location
1449                        .checked_sub(deduction.for_location())
1450                        .unwrap();
1451                }
1452
1453                let mut num_user_defined_outputs = 0;
1454
1455                for output in entry_point.outputs.iter() {
1456                    match *output {
1457                        Varying::Local { ref iv, location } => {
1458                            if location > max_vertex_shader_output_location {
1459                                return Err(StageError::VertexOutputLocationTooLarge {
1460                                    location,
1461                                    var: iv.clone(),
1462                                    limit: self.limits.max_inter_stage_shader_variables,
1463                                    deductions: deductions.collect(),
1464                                });
1465                            }
1466                            num_user_defined_outputs += 1;
1467                        }
1468                        Varying::BuiltIn(_) => {}
1469                    };
1470
1471                    if let Some(
1472                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
1473                    ) = compare_function
1474                    {
1475                        if let Varying::BuiltIn(naga::BuiltIn::Position { invariant: false }) =
1476                            *output
1477                        {
1478                            log::warn!(
1479                                concat!(
1480                                    "Vertex shader with entry point {} outputs a ",
1481                                    "@builtin(position) without the @invariant attribute and ",
1482                                    "is used in a pipeline with {cmp:?}. On some machines, ",
1483                                    "this can cause bad artifacting as {cmp:?} assumes the ",
1484                                    "values output from the vertex shader exactly match the ",
1485                                    "value in the depth buffer. The @invariant attribute on the ",
1486                                    "@builtin(position) vertex output ensures that the exact ",
1487                                    "same pixel depths are used every render."
1488                                ),
1489                                entry_point_name,
1490                                cmp = cmp
1491                            );
1492                        }
1493                    }
1494                }
1495
1496                if num_user_defined_outputs > max_vertex_shader_output_variables {
1497                    return Err(StageError::TooManyUserDefinedVertexOutputs {
1498                        num_found: num_user_defined_outputs,
1499                        limit: self.limits.max_inter_stage_shader_variables,
1500                        deductions: deductions.collect(),
1501                    });
1502                }
1503            }
1504            ShaderStageForValidation::Fragment {
1505                dual_source_blending,
1506                has_depth_attachment,
1507            } => {
1508                let mut max_fragment_shader_input_variables =
1509                    self.limits.max_inter_stage_shader_variables;
1510
1511                let deductions = entry_point.inputs.iter().filter_map(|output| match output {
1512                    Varying::Local { .. } => None,
1513                    Varying::BuiltIn(builtin) => {
1514                        MaxFragmentShaderInputDeduction::from_inter_stage_builtin(*builtin).or_else(
1515                            || {
1516                                unreachable!(
1517                                    concat!(
1518                                        "unexpected built-in provided; ",
1519                                        "{:?} is not used for fragment stage input",
1520                                    ),
1521                                    builtin
1522                                )
1523                            },
1524                        )
1525                    }
1526                });
1527
1528                for deduction in deductions.clone() {
1529                    // NOTE: Deductions, in the current version of the spec. we implement, do not
1530                    // ever exceed the minimum variables available.
1531                    max_fragment_shader_input_variables = max_fragment_shader_input_variables
1532                        .checked_sub(deduction.for_variables())
1533                        .unwrap();
1534                }
1535
1536                let mut num_user_defined_inputs = 0;
1537
1538                for output in entry_point.inputs.iter() {
1539                    match *output {
1540                        Varying::Local { ref iv, location } => {
1541                            if location >= self.limits.max_inter_stage_shader_variables {
1542                                return Err(StageError::FragmentInputLocationTooLarge {
1543                                    location,
1544                                    var: iv.clone(),
1545                                    limit: self.limits.max_inter_stage_shader_variables,
1546                                    deductions: deductions.collect(),
1547                                });
1548                            }
1549                            num_user_defined_inputs += 1;
1550                        }
1551                        Varying::BuiltIn(_) => {}
1552                    };
1553                }
1554
1555                if num_user_defined_inputs > max_fragment_shader_input_variables {
1556                    return Err(StageError::TooManyUserDefinedFragmentInputs {
1557                        num_found: num_user_defined_inputs,
1558                        limit: self.limits.max_inter_stage_shader_variables,
1559                        deductions: deductions.collect(),
1560                    });
1561                }
1562
1563                for output in &entry_point.outputs {
1564                    let &Varying::Local { location, ref iv } = output else {
1565                        continue;
1566                    };
1567                    if location >= self.limits.max_color_attachments {
1568                        return Err(StageError::ColorAttachmentLocationTooLarge {
1569                            location,
1570                            var: iv.clone(),
1571                            limit: self.limits.max_color_attachments,
1572                        });
1573                    }
1574                }
1575
1576                // If the pipeline uses dual-source blending, then the shader
1577                // must configure appropriate I/O, but it is not an error to
1578                // use a shader that defines the I/O in a pipeline that only
1579                // uses one blend source.
1580                if dual_source_blending && !entry_point.dual_source_blending {
1581                    return Err(StageError::InvalidDualSourceBlending);
1582                }
1583
1584                if entry_point
1585                    .outputs
1586                    .contains(&Varying::BuiltIn(naga::BuiltIn::FragDepth))
1587                    && !has_depth_attachment
1588                {
1589                    return Err(StageError::MissingFragDepthAttachment);
1590                }
1591            }
1592            ShaderStageForValidation::Mesh => {
1593                for output in &entry_point.outputs {
1594                    if matches!(output, Varying::BuiltIn(naga::BuiltIn::PrimitiveIndex)) {
1595                        this_stage_primitive_index = true;
1596                    }
1597                }
1598            }
1599            _ => (),
1600        }
1601
1602        if let Some(ref mesh_info) = entry_point.mesh_info {
1603            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
1604                return Err(StageError::TooManyMeshVertices {
1605                    limit: self.limits.max_mesh_output_vertices,
1606                    value: mesh_info.max_vertices,
1607                });
1608            }
1609            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
1610                return Err(StageError::TooManyMeshPrimitives {
1611                    limit: self.limits.max_mesh_output_primitives,
1612                    value: mesh_info.max_primitives,
1613                });
1614            }
1615        }
1616        if let Some(task_payload_size) = entry_point.task_payload_size {
1617            if task_payload_size > self.limits.max_task_payload_size {
1618                return Err(StageError::TaskPayloadTooLarge {
1619                    limit: self.limits.max_task_payload_size,
1620                    value: task_payload_size,
1621                });
1622            }
1623        }
1624        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1625            && entry_point.task_payload_size != inputs.task_payload_size
1626        {
1627            return Err(StageError::TaskPayloadMustMatch {
1628                input: inputs.task_payload_size,
1629                shader: entry_point.task_payload_size,
1630            });
1631        }
1632
1633        // Fragment shader primitive index is treated like a varying
1634        if shader_stage.to_naga() == naga::ShaderStage::Fragment
1635            && this_stage_primitive_index
1636            && inputs.primitive_index == Some(false)
1637        {
1638            return Err(StageError::InvalidPrimitiveIndex);
1639        } else if shader_stage.to_naga() == naga::ShaderStage::Fragment
1640            && !this_stage_primitive_index
1641            && inputs.primitive_index == Some(true)
1642        {
1643            return Err(StageError::MissingPrimitiveIndex);
1644        }
1645        if shader_stage.to_naga() == naga::ShaderStage::Mesh
1646            && inputs.task_payload_size.is_some()
1647            && has_draw_id
1648        {
1649            return Err(StageError::DrawIdError);
1650        }
1651
1652        let outputs = entry_point
1653            .outputs
1654            .iter()
1655            .filter_map(|output| match *output {
1656                Varying::Local { location, ref iv } => Some((location, iv.clone())),
1657                Varying::BuiltIn(_) => None,
1658            })
1659            .collect();
1660
1661        Ok(StageIo {
1662            task_payload_size: entry_point.task_payload_size,
1663            varyings: outputs,
1664            primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
1665                Some(this_stage_primitive_index)
1666            } else {
1667                None
1668            },
1669        })
1670    }
1671
1672    pub fn fragment_uses_dual_source_blending(
1673        &self,
1674        entry_point_name: &str,
1675    ) -> Result<bool, StageError> {
1676        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
1677        self.entry_points
1678            .get(&pair)
1679            .ok_or(StageError::MissingEntryPoint(pair.1))
1680            .map(|ep| ep.dual_source_blending)
1681    }
1682}
1683
1684/// Validate a list of color attachment formats against `maxColorAttachmentBytesPerSample`.
1685///
1686/// The color attachments can be from a render pass descriptor or a pipeline descriptor.
1687///
1688/// Implements <https://gpuweb.github.io/gpuweb/#abstract-opdef-calculating-color-attachment-bytes-per-sample>.
1689pub fn validate_color_attachment_bytes_per_sample(
1690    attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1691    limit: u32,
1692) -> Result<(), crate::command::ColorAttachmentError> {
1693    let mut total_bytes_per_sample: u32 = 0;
1694    for format in attachment_formats {
1695        let byte_cost = format.target_pixel_byte_cost().unwrap();
1696        let alignment = format.target_component_alignment().unwrap();
1697
1698        total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1699        total_bytes_per_sample += byte_cost;
1700    }
1701
1702    if total_bytes_per_sample > limit {
1703        return Err(
1704            crate::command::ColorAttachmentError::TooManyBytesPerSample {
1705                total: total_bytes_per_sample,
1706                limit,
1707            },
1708        );
1709    }
1710
1711    Ok(())
1712}
1713
1714pub enum ShaderStageForValidation {
1715    Vertex {
1716        topology: wgt::PrimitiveTopology,
1717        compare_function: Option<wgt::CompareFunction>,
1718    },
1719    Mesh,
1720    Fragment {
1721        dual_source_blending: bool,
1722        has_depth_attachment: bool,
1723    },
1724    Compute,
1725    Task,
1726}
1727
1728impl ShaderStageForValidation {
1729    pub fn to_naga(&self) -> naga::ShaderStage {
1730        match self {
1731            Self::Vertex { .. } => naga::ShaderStage::Vertex,
1732            Self::Mesh => naga::ShaderStage::Mesh,
1733            Self::Fragment { .. } => naga::ShaderStage::Fragment,
1734            Self::Compute => naga::ShaderStage::Compute,
1735            Self::Task => naga::ShaderStage::Task,
1736        }
1737    }
1738
1739    pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
1740        match self {
1741            Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
1742            Self::Mesh => wgt::ShaderStages::MESH,
1743            Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
1744            Self::Compute => wgt::ShaderStages::COMPUTE,
1745            Self::Task => wgt::ShaderStages::TASK,
1746        }
1747    }
1748}