use alloc::{boxed::Box, string::String, vec::Vec};
use core::{fmt, num::NonZeroU32};

use crate::{
    binding_model,
    hub::Hub,
    id::{BindGroupLayoutId, PipelineLayoutId},
    resource::{
        Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
        ResourceErrorIdent,
    },
    snatch::SnatchGuard,
    Label, DOWNLEVEL_ERROR_MESSAGE,
};

use arrayvec::ArrayVec;
use smallvec::SmallVec;
use thiserror::Error;
use wgt::{BufferAddress, DeviceLostReason, TextureFormat};

pub(crate) mod bgl;
pub mod global;
mod life;
pub mod queue;
pub mod ray_tracing;
pub mod resource;
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
pub use {life::WaitIdleError, resource::Device};

pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

const CLEANUP_WAIT_MS: u32 = 60000;

pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;

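/// Whether a buffer mapping exposes the buffer's contents to the host for
/// reading or for writing.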
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    Read,
    Write,
}

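/// Per-attachment data for a render pass: color attachments, their resolve
/// targets, and an optional depth-stencil attachment.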
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub depth_stencil: Option<T>,
}
impl<T: PartialEq> Eq for AttachmentData<T> {}

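/// The attachment formats, sample count, and multiview setting of a render
/// pass, used to check compatibility with other resources via
/// [`Self::check_compatible`].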
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    pub attachments: AttachmentData<TextureFormat>,
    pub sample_count: u32,
    pub multiview: Option<NonZeroU32>,
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum RenderPassCompatibilityError {
    #[error(
        "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
    )]
    IncompatibleColorAttachment {
        indices: Vec<usize>,
        expected: Vec<Option<TextureFormat>>,
        actual: Vec<Option<TextureFormat>>,
        res: ResourceErrorIdent,
    },
    #[error(
        "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
    )]
    IncompatibleDepthStencilAttachment {
        expected: Option<TextureFormat>,
        actual: Option<TextureFormat>,
        res: ResourceErrorIdent,
    },
    #[error(
        "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with sample count {actual:?}",
    )]
    IncompatibleSampleCount {
        expected: u32,
        actual: u32,
        res: ResourceErrorIdent,
    },
    #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
    IncompatibleMultiview {
        expected: Option<NonZeroU32>,
        actual: Option<NonZeroU32>,
        res: ResourceErrorIdent,
    },
}

impl RenderPassContext {
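    /// Checks that `other` (the context of `res`) uses the same attachment
    /// formats, sample count, and multiview setting as this render pass,
    /// returning a [`RenderPassCompatibilityError`] describing the first
    /// mismatch found.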
    pub(crate) fn check_compatible<T: Labeled>(
        &self,
        other: &Self,
        res: &T,
    ) -> Result<(), RenderPassCompatibilityError> {
        if self.attachments.colors != other.attachments.colors {
            let indices = self
                .attachments
                .colors
                .iter()
                .zip(&other.attachments.colors)
                .enumerate()
                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
                .collect();
            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
                indices,
                expected: self.attachments.colors.iter().cloned().collect(),
                actual: other.attachments.colors.iter().cloned().collect(),
                res: res.error_ident(),
            });
        }
        if self.attachments.depth_stencil != other.attachments.depth_stencil {
            return Err(
                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
                    expected: self.attachments.depth_stencil,
                    actual: other.attachments.depth_stencil,
                    res: res.error_ident(),
                },
            );
        }
        if self.sample_count != other.sample_count {
            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
                expected: self.sample_count,
                actual: other.sample_count,
                res: res.error_ident(),
            });
        }
        if self.multiview != other.multiview {
            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
                expected: self.multiview,
                actual: other.multiview,
                res: res.error_ident(),
            });
        }
        Ok(())
    }
}

pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

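/// User callbacks gathered during device and queue processing: buffer-map
/// completions, submitted-work-done notifications, and device-lost
/// notifications. They are invoked together by [`Self::fire`].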
#[derive(Default)]
pub struct UserClosures {
    pub mappings: Vec<BufferMapPendingClosure>,
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}

impl UserClosures {
    fn extend(&mut self, other: Self) {
        self.mappings.extend(other.mappings);
        self.submissions.extend(other.submissions);
        self.device_lost_invocations
            .extend(other.device_lost_invocations);
    }

    fn fire(self) {
        for (mut operation, status) in self.mappings {
            if let Some(callback) = operation.callback.take() {
                callback(status);
            }
        }
        for closure in self.submissions {
            closure();
        }
        for invocation in self.device_lost_invocations {
            (invocation.closure)(invocation.reason, invocation.message);
        }
    }
}

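/// Callback invoked when a device is lost, receiving the [`DeviceLostReason`]
/// and a message string.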
#[cfg(send_sync)]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(send_sync))]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;

pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}

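/// Maps `size` bytes of `buffer` starting at `offset` for host access of the
/// given `kind`, zero-filling any ranges that were never initialized.
///
/// For non-coherent read mappings the range is invalidated first so the host
/// sees the GPU's latest writes.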
pub(crate) fn map_buffer(
    buffer: &Buffer,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
    let raw_device = buffer.device.raw();
    let raw_buffer = buffer.try_raw(snatch_guard)?;
    let mapping = unsafe {
        raw_device
            .map_buffer(raw_buffer, offset..offset + size)
            .map_err(|e| buffer.device.handle_hal_error(e))?
    };

    // For a non-coherent read mapping, make the GPU's writes visible to the
    // host before the caller looks at the data.
    if !mapping.is_coherent && kind == HostMap::Read {
        #[allow(clippy::single_range_in_vec_init)]
        unsafe {
            raw_device.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
        }
    }

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    let mapped = unsafe { core::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // Zero-fill any not-yet-initialized parts of the mapping. For a
    // non-coherent read-only mapping the zeroes cannot be flushed back to the
    // GPU, so those ranges stay marked as uninitialized; otherwise they are
    // drained (marked initialized) here.
    if !mapping.is_coherent
        && kind == HostMap::Read
        && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
    {
        for uninitialized in buffer
            .initialization_status
            .write()
            .uninitialized(offset..(size + offset))
        {
            // The mapping's pointer is already offset; the tracked range is
            // relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);
        }
    } else {
        for uninitialized in buffer
            .initialization_status
            .write()
            .drain(offset..(size + offset))
        {
            // The mapping's pointer is already offset; the tracked range is
            // relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);

            if !mapping.is_coherent
                && kind == HostMap::Read
                && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
            {
                unsafe { raw_device.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
            }
        }
    }

    Ok(mapping)
}

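/// Error raised when two resources involved in an operation belong to
/// different devices.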
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeviceMismatch {
    pub(super) res: ResourceErrorIdent,
    pub(super) res_device: ResourceErrorIdent,
    pub(super) target: Option<ResourceErrorIdent>,
    pub(super) target_device: ResourceErrorIdent,
}

impl fmt::Display for DeviceMismatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(
            f,
            "{} of {} doesn't match {}",
            self.res_device, self.res, self.target_device
        )?;
        if let Some(target) = self.target.as_ref() {
            write!(f, " of {target}")?;
        }
        Ok(())
    }
}

impl core::error::Error for DeviceMismatch {}

#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DeviceError {
    #[error("{0} is invalid.")]
    Invalid(ResourceErrorIdent),
    #[error("Parent device is lost")]
    Lost,
    #[error("Not enough memory left.")]
    OutOfMemory,
    #[error("Creation of a resource failed for a reason other than running out of memory.")]
    ResourceCreationFailed,
    #[error(transparent)]
    DeviceMismatch(#[from] Box<DeviceMismatch>),
}

impl DeviceError {
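    /// Converts a [`hal::DeviceError`] into a [`DeviceError`], reporting
    /// unexpected HAL errors as [`DeviceError::Lost`].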
    pub fn from_hal(error: hal::DeviceError) -> Self {
        match error {
            hal::DeviceError::Lost => Self::Lost,
            hal::DeviceError::OutOfMemory => Self::OutOfMemory,
            hal::DeviceError::ResourceCreationFailed => Self::ResourceCreationFailed,
            hal::DeviceError::Unexpected => Self::Lost,
        }
    }
}

#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{DOWNLEVEL_ERROR_MESSAGE}",
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);

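/// Assigned IDs for an implicitly created pipeline layout and its bind group
/// layouts.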
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    pub root_id: PipelineLayoutId,
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}

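/// Caller-provided IDs for an implicitly created pipeline layout and its bind
/// group layouts; [`Self::prepare`] registers them with the [`Hub`].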
pub struct ImplicitPipelineIds<'a> {
    pub root_id: PipelineLayoutId,
    pub group_ids: &'a [BindGroupLayoutId],
}

impl ImplicitPipelineIds<'_> {
    fn prepare(self, hub: &Hub) -> ImplicitPipelineContext {
        ImplicitPipelineContext {
            root_id: hub.pipeline_layouts.prepare(Some(self.root_id)).id(),
            group_ids: self
                .group_ids
                .iter()
                .map(|id_in| hub.bind_group_layouts.prepare(Some(*id_in)).id())
                .collect(),
        }
    }
}

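/// Creates a naga [`Validator`](naga::valid::Validator) whose capabilities
/// match the given device features and downlevel flags.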
pub fn create_validator(
    features: wgt::Features,
    downlevel: wgt::DownlevelFlags,
    flags: naga::valid::ValidationFlags,
) -> naga::valid::Validator {
    use naga::valid::Capabilities as Caps;
    let mut caps = Caps::empty();
    caps.set(
        Caps::PUSH_CONSTANT,
        features.contains(wgt::Features::PUSH_CONSTANTS),
    );
    caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
    caps.set(
        Caps::SHADER_FLOAT16,
        features.contains(wgt::Features::SHADER_F16),
    );
    caps.set(
        Caps::PRIMITIVE_INDEX,
        features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
    );
    caps.set(
        Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
        features
            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
    );
    caps.set(
        Caps::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
        features.contains(wgt::Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
    );
    caps.set(
        Caps::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
        features.contains(wgt::Features::UNIFORM_BUFFER_BINDING_ARRAYS),
    );
    caps.set(
        Caps::SAMPLER_NON_UNIFORM_INDEXING,
        features
            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
    );
    caps.set(
        Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
        features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
    );
    caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
    caps.set(
        Caps::EARLY_DEPTH_TEST,
        features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
    );
    caps.set(
        Caps::SHADER_INT64,
        features.contains(wgt::Features::SHADER_INT64),
    );
    caps.set(
        Caps::SHADER_INT64_ATOMIC_MIN_MAX,
        features.intersects(
            wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
        ),
    );
    caps.set(
        Caps::SHADER_INT64_ATOMIC_ALL_OPS,
        features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
    );
    caps.set(
        Caps::TEXTURE_ATOMIC,
        features.contains(wgt::Features::TEXTURE_ATOMIC),
    );
    caps.set(
        Caps::TEXTURE_INT64_ATOMIC,
        features.contains(wgt::Features::TEXTURE_INT64_ATOMIC),
    );
    caps.set(
        Caps::SHADER_FLOAT32_ATOMIC,
        features.contains(wgt::Features::SHADER_FLOAT32_ATOMIC),
    );
    caps.set(
        Caps::MULTISAMPLED_SHADING,
        downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
    );
    caps.set(
        Caps::DUAL_SOURCE_BLENDING,
        features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
    );
    caps.set(
        Caps::CUBE_ARRAY_TEXTURES,
        downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
    );
    caps.set(
        Caps::SUBGROUP,
        features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
    );
    caps.set(
        Caps::SUBGROUP_BARRIER,
        features.intersects(wgt::Features::SUBGROUP_BARRIER),
    );
    caps.set(
        Caps::RAY_QUERY,
        features.intersects(wgt::Features::EXPERIMENTAL_RAY_QUERY),
    );
    caps.set(
        Caps::SUBGROUP_VERTEX_STAGE,
        features.contains(wgt::Features::SUBGROUP_VERTEX),
    );
    caps.set(
        Caps::RAY_HIT_VERTEX_POSITION,
        features.intersects(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
    );

    naga::valid::Validator::new(flags, caps)
}