use command_buffer::RecordingCommandBuffer;
use concurrent_slotmap::SlotId;
use graph::{CompileInfo, ExecuteError, ResourceMap, TaskGraph};
use linear_map::LinearMap;
use resource::{
    AccessTypes, BufferState, Flight, HostAccessType, ImageLayoutType, ImageState, Resources,
    SwapchainState,
};
use std::{
    any::{Any, TypeId},
    cell::Cell,
    cmp,
    error::Error,
    fmt,
    hash::{Hash, Hasher},
    marker::PhantomData,
    mem,
    ops::{Deref, RangeBounds},
    sync::Arc,
};
use vulkano::{
    buffer::{Buffer, BufferContents, BufferMemory, Subbuffer},
    command_buffer as raw,
    device::Queue,
    format::ClearValue,
    image::Image,
    render_pass::Framebuffer,
    swapchain::Swapchain,
    DeviceSize, ValidationError,
};

pub mod command_buffer;
pub mod graph;
mod linear_map;
pub mod resource;

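/// Creates a [`TaskGraph`] with a single task node built from the given closure, compiles it,
/// and executes it on the given queue.
///
/// A minimal usage sketch, not taken from the original source: it assumes a `resources`
/// collection, a `queue`, a `flight_id`, and a host-visible buffer `buffer_id` that were
/// created beforehand, with this `execute` function imported.
///
/// ```ignore
/// unsafe {
///     execute(
///         &queue,
///         &resources,
///         flight_id,
///         |_cbf, tcx| {
///             *tcx.write_buffer::<u32>(buffer_id, ..4)? = 42;
///             Ok(())
///         },
///         [(buffer_id, HostAccessType::Write)],
///         [],
///         [],
///     )
/// }
/// .unwrap();
/// ```
///
/// # Safety
///
/// For the duration of the execution, the resources in the given access sets must not be
/// accessed in any way other than the accesses specified here.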
pub unsafe fn execute(
    queue: &Arc<Queue>,
    resources: &Arc<Resources>,
    flight_id: Id<Flight>,
    task: impl FnOnce(&mut RecordingCommandBuffer<'_>, &mut TaskContext<'_>) -> TaskResult,
    host_buffer_accesses: impl IntoIterator<Item = (Id<Buffer>, HostAccessType)>,
    buffer_accesses: impl IntoIterator<Item = (Id<Buffer>, AccessTypes)>,
    image_accesses: impl IntoIterator<Item = (Id<Image>, AccessTypes, ImageLayoutType)>,
) -> Result<(), ExecuteError> {
    #[repr(transparent)]
    struct OnceTask<'a>(
        &'a dyn Fn(&mut RecordingCommandBuffer<'_>, &mut TaskContext<'_>) -> TaskResult,
    );

    // SAFETY: The graph is compiled and executed on this thread before the function returns,
    // so the contained reference is never actually sent to another thread.
    unsafe impl Send for OnceTask<'_> {}

    // SAFETY: Same as for `Send` above: the reference is never shared with another thread.
    unsafe impl Sync for OnceTask<'_> {}

    impl Task for OnceTask<'static> {
        type World = ();

        unsafe fn execute(
            &self,
            cbf: &mut RecordingCommandBuffer<'_>,
            tcx: &mut TaskContext<'_>,
            _: &Self::World,
        ) -> TaskResult {
            (self.0)(cbf, tcx)
        }
    }

    let task = Cell::new(Some(task));
    let trampoline = move |cbf: &mut RecordingCommandBuffer<'_>, tcx: &mut TaskContext<'_>| {
        // The trampoline is called at most once, so the cell still contains the task.
        (Cell::take(&task).unwrap())(cbf, tcx)
    };

    let mut task_graph = TaskGraph::new(resources, 1, 64 * 1024);

    for (id, access_type) in host_buffer_accesses {
        task_graph.add_host_buffer_access(id, access_type);
    }

    let mut node = task_graph.create_task_node(
        "",
        QueueFamilyType::Specific {
            index: queue.queue_family_index(),
        },
        // SAFETY: The graph, and with it the task, is executed and dropped before `trampoline`
        // goes out of scope, so extending the lifetime to 'static is sound here.
        unsafe { mem::transmute::<OnceTask<'_>, OnceTask<'static>>(OnceTask(&trampoline)) },
    );

    for (id, access_types) in buffer_accesses {
        node.buffer_access(id, access_types);
    }

    for (id, access_types, layout_type) in image_accesses {
        node.image_access(id, access_types, layout_type);
    }

    let task_graph = unsafe {
        // SAFETY: The graph's only node was created for the given queue's family.
        task_graph.compile(&CompileInfo {
            queues: &[queue],
            present_queue: None,
            flight_id,
            _ne: crate::NE,
        })
    }
    .unwrap();

    let resource_map = ResourceMap::new(&task_graph).unwrap();

    // SAFETY: Upheld by the caller, per this function's safety contract.
    unsafe { task_graph.execute(resource_map, &(), || {}) }
}

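/// A unit of work to be recorded to a command buffer and executed as a node of a task graph.
///
/// A minimal sketch of an implementation, not taken from the original source; the accesses the
/// task performs would additionally have to be declared when building the graph:
///
/// ```ignore
/// struct MyTask {
///     buffer_id: Id<Buffer>,
/// }
///
/// impl Task for MyTask {
///     type World = ();
///
///     unsafe fn execute(
///         &self,
///         cbf: &mut RecordingCommandBuffer<'_>,
///         tcx: &mut TaskContext<'_>,
///         _world: &Self::World,
///     ) -> TaskResult {
///         // Record commands with `cbf`, or touch host-mapped memory through `tcx`.
///         Ok(())
///     }
/// }
/// ```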
pub trait Task: Any + Send + Sync {
    /// The type of the world given to the task on execution.
    type World: ?Sized;
    #[allow(unused)]
    fn clear_values(&self, clear_values: &mut ClearValues<'_>) {}

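    /// Executes the task, recording its commands to the given command buffer.
    ///
    /// # Safety
    ///
    /// This must only be called as part of a task graph execution, which is what establishes
    /// the synchronization for the accesses the task was declared with.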
    unsafe fn execute(
        &self,
        cbf: &mut RecordingCommandBuffer<'_>,
        tcx: &mut TaskContext<'_>,
        world: &Self::World,
    ) -> TaskResult;
}

impl<W: ?Sized + 'static> dyn Task<World = W> {
    /// Returns `true` if the inner type is the same as `T`.
    #[inline]
    pub fn is<T: Task<World = W>>(&self) -> bool {
        self.type_id() == TypeId::of::<T>()
    }

    /// Returns a reference to the inner value if it is of type `T`, or returns `None`
    /// otherwise.
    #[inline]
    pub fn downcast_ref<T: Task<World = W>>(&self) -> Option<&T> {
        if self.is::<T>() {
            // SAFETY: We just checked that the inner type is `T`.
            Some(unsafe { self.downcast_unchecked_ref() })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the inner value if it is of type `T`, or returns `None`
    /// otherwise.
    #[inline]
    pub fn downcast_mut<T: Task<World = W>>(&mut self) -> Option<&mut T> {
        if self.is::<T>() {
            // SAFETY: We just checked that the inner type is `T`.
            Some(unsafe { self.downcast_unchecked_mut() })
        } else {
            None
        }
    }

    /// Returns a reference to the inner value without checking its type.
    ///
    /// # Safety
    ///
    /// The inner type must be `T`.
    #[inline]
    pub unsafe fn downcast_unchecked_ref<T: Task<World = W>>(&self) -> &T {
        unsafe { &*<*const dyn Task<World = W>>::cast::<T>(self) }
    }

    /// Returns a mutable reference to the inner value without checking its type.
    ///
    /// # Safety
    ///
    /// The inner type must be `T`.
    #[inline]
    pub unsafe fn downcast_unchecked_mut<T: Task<World = W>>(&mut self) -> &mut T {
        unsafe { &mut *<*mut dyn Task<World = W>>::cast::<T>(self) }
    }
}

impl<W: ?Sized> fmt::Debug for dyn Task<World = W> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Task").finish_non_exhaustive()
    }
}

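/// `PhantomData` implements `Task` as a no-op, making it usable as a placeholder task.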
impl<W: ?Sized + 'static> Task for PhantomData<fn() -> W> {
    type World = W;

    unsafe fn execute(
        &self,
        _cbf: &mut RecordingCommandBuffer<'_>,
        _tcx: &mut TaskContext<'_>,
        _world: &Self::World,
    ) -> TaskResult {
        Ok(())
    }
}

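/// The context of a task during execution, giving access to the resources of the task graph
/// and the index of the current frame in flight.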
pub struct TaskContext<'a> {
    resource_map: &'a ResourceMap<'a>,
    current_frame_index: u32,
    command_buffers: &'a mut Vec<Arc<raw::CommandBuffer>>,
}

impl<'a> TaskContext<'a> {
    /// Returns the buffer corresponding to `id`, or returns an error if it isn't present.
    #[inline]
    pub fn buffer(&self, id: Id<Buffer>) -> TaskResult<&'a BufferState> {
        if id.is_virtual() {
            Ok(unsafe { self.resource_map.buffer(id) }?)
        } else {
            Ok(unsafe { self.resource_map.resources().buffer_unprotected(id) }?)
        }
    }

    /// Returns the image corresponding to `id`, or returns an error if it isn't present.
    ///
    /// # Panics
    ///
    /// - Panics if `id` refers to a swapchain image; use [`swapchain`](Self::swapchain)
    ///   instead.
    #[inline]
    pub fn image(&self, id: Id<Image>) -> TaskResult<&'a ImageState> {
        assert_ne!(id.object_type(), ObjectType::Swapchain);

        if id.is_virtual() {
            Ok(unsafe { self.resource_map.image(id) }?)
        } else {
            Ok(unsafe { self.resource_map.resources().image_unprotected(id) }?)
        }
    }

    /// Returns the swapchain corresponding to `id`, or returns an error if it isn't present.
    #[inline]
    pub fn swapchain(&self, id: Id<Swapchain>) -> TaskResult<&'a SwapchainState> {
        if id.is_virtual() {
            Ok(unsafe { self.resource_map.swapchain(id) }?)
        } else {
            Ok(unsafe { self.resource_map.resources().swapchain_unprotected(id) }?)
        }
    }

    /// Returns the resource map of the current execution.
    #[inline]
    pub fn resource_map(&self) -> &'a ResourceMap<'a> {
        self.resource_map
    }

    /// Returns the index of the current frame in flight.
    #[inline]
    #[must_use]
    pub fn current_frame_index(&self) -> u32 {
        self.current_frame_index
    }

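    /// Tries to get read access to a portion of the buffer corresponding to `id`, interpreting
    /// the given `range` of bytes as a value of type `T`.
    ///
    /// This requires that the task graph was given a `HostAccessType::Read` access for the
    /// buffer, and that the buffer's memory is host-mapped.
    ///
    /// A usage sketch, not taken from the original source, assuming a `tcx: &mut
    /// TaskContext<'_>` inside a task and a `buffer_id` with the required host access:
    ///
    /// ```ignore
    /// let value: &u32 = tcx.read_buffer(buffer_id, ..4)?;
    /// ```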
    pub fn read_buffer<T: BufferContents + ?Sized>(
        &self,
        id: Id<Buffer>,
        range: impl RangeBounds<DeviceSize>,
    ) -> TaskResult<&T> {
        self.validate_read_buffer(id)?;

        // SAFETY: We just validated that the task graph declares the host access.
        unsafe { self.read_buffer_unchecked(id, range) }
    }

    fn validate_read_buffer(&self, id: Id<Buffer>) -> Result<(), Box<ValidationError>> {
        if !self
            .resource_map
            .virtual_resources()
            .contains_host_buffer_access(id, HostAccessType::Read)
        {
            return Err(Box::new(ValidationError {
                context: "TaskContext::read_buffer".into(),
                problem: "the task graph does not have an access of type `HostAccessType::Read` \
                    for the buffer"
                    .into(),
                ..Default::default()
            }));
        }

        Ok(())
    }

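    /// Gets read access to a portion of the buffer corresponding to `id` without checking that
    /// the task graph declared the required host access.
    ///
    /// # Safety
    ///
    /// The caller must ensure that a `HostAccessType::Read` access for the buffer was added to
    /// the task graph, so that the access is synchronized.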
    pub unsafe fn read_buffer_unchecked<T: BufferContents + ?Sized>(
        &self,
        id: Id<Buffer>,
        range: impl RangeBounds<DeviceSize>,
    ) -> TaskResult<&T> {
        assert!(T::LAYOUT.alignment().as_devicesize() <= 64);

        let buffer = self.buffer(id)?.buffer();
        let subbuffer = Subbuffer::from(buffer.clone())
            .slice(range)
            .reinterpret::<T>();

        let allocation = match buffer.memory() {
            BufferMemory::Normal(a) => a,
            BufferMemory::Sparse => {
                todo!("`TaskContext::read_buffer` doesn't support sparse binding yet");
            }
            BufferMemory::External => {
                return Err(TaskError::HostAccess(HostAccessError::Unmanaged));
            }
            _ => unreachable!(),
        };

        // Ensure that the allocation is host-mapped before taking the mapped slice below; the
        // mapping returned here is otherwise unused.
        unsafe { allocation.mapped_slice_unchecked(..) }.map_err(|err| match err {
            vulkano::sync::HostAccessError::NotHostMapped => HostAccessError::NotHostMapped,
            vulkano::sync::HostAccessError::OutOfMappedRange => HostAccessError::OutOfMappedRange,
            _ => unreachable!(),
        })?;

        let mapped_slice = subbuffer.mapped_slice().unwrap();

        // SAFETY: The caller guarantees that the access is synchronized by the task graph.
        let data_ptr = unsafe { T::ptr_from_slice(mapped_slice) };
        let data = unsafe { &*data_ptr };

        Ok(data)
    }

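    /// Tries to get write access to a portion of the buffer corresponding to `id`,
    /// interpreting the given `range` of bytes as a value of type `T`.
    ///
    /// This requires that the task graph was given a `HostAccessType::Write` access for the
    /// buffer, and that the buffer's memory is host-mapped.
    ///
    /// A usage sketch, not taken from the original source, assuming a `tcx: &mut
    /// TaskContext<'_>` inside a task and a `buffer_id` with the required host access:
    ///
    /// ```ignore
    /// *tcx.write_buffer::<u32>(buffer_id, ..4)? = 42;
    /// ```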
    pub fn write_buffer<T: BufferContents + ?Sized>(
        &mut self,
        id: Id<Buffer>,
        range: impl RangeBounds<DeviceSize>,
    ) -> TaskResult<&mut T> {
        self.validate_write_buffer(id)?;

        // SAFETY: We just validated that the task graph declares the host access.
        unsafe { self.write_buffer_unchecked(id, range) }
    }

    fn validate_write_buffer(&self, id: Id<Buffer>) -> Result<(), Box<ValidationError>> {
        if !self
            .resource_map
            .virtual_resources()
            .contains_host_buffer_access(id, HostAccessType::Write)
        {
            return Err(Box::new(ValidationError {
                context: "TaskContext::write_buffer".into(),
                problem: "the task graph does not have an access of type `HostAccessType::Write` \
                    for the buffer"
                    .into(),
                ..Default::default()
            }));
        }

        Ok(())
    }

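    /// Gets write access to a portion of the buffer corresponding to `id` without checking
    /// that the task graph declared the required host access.
    ///
    /// # Safety
    ///
    /// The caller must ensure that a `HostAccessType::Write` access for the buffer was added
    /// to the task graph, so that the access is synchronized.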
    pub unsafe fn write_buffer_unchecked<T: BufferContents + ?Sized>(
        &mut self,
        id: Id<Buffer>,
        range: impl RangeBounds<DeviceSize>,
    ) -> TaskResult<&mut T> {
        assert!(T::LAYOUT.alignment().as_devicesize() <= 64);

        let buffer = self.buffer(id)?.buffer();
        let subbuffer = Subbuffer::from(buffer.clone())
            .slice(range)
            .reinterpret::<T>();

        let allocation = match buffer.memory() {
            BufferMemory::Normal(a) => a,
            BufferMemory::Sparse => {
                todo!("`TaskContext::write_buffer` doesn't support sparse binding yet");
            }
            BufferMemory::External => {
                return Err(TaskError::HostAccess(HostAccessError::Unmanaged));
            }
            _ => unreachable!(),
        };

        // Ensure that the allocation is host-mapped before taking the mapped slice below; the
        // mapping returned here is otherwise unused.
        unsafe { allocation.mapped_slice_unchecked(..) }.map_err(|err| match err {
            vulkano::sync::HostAccessError::NotHostMapped => HostAccessError::NotHostMapped,
            vulkano::sync::HostAccessError::OutOfMappedRange => HostAccessError::OutOfMappedRange,
            _ => unreachable!(),
        })?;

        let mapped_slice = subbuffer.mapped_slice().unwrap();

        // SAFETY: The caller guarantees that the access is synchronized by the task graph.
        let data_ptr = unsafe { T::ptr_from_slice(mapped_slice) };
        let data = unsafe { &mut *data_ptr };

        Ok(data)
    }

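    /// Pushes a command buffer into the list of command buffers to be submitted to the queue
    /// as part of the current task's execution.
    ///
    /// # Safety
    ///
    /// Since the contents of the given command buffer can't be inspected here, the caller must
    /// ensure that its commands are synchronized with the accesses declared in the task graph.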
    #[inline]
    pub unsafe fn push_command_buffer(&mut self, command_buffer: Arc<raw::CommandBuffer>) {
        self.command_buffers.push(command_buffer);
    }

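    /// Extends the list of command buffers to be submitted to the queue with the given command
    /// buffers.
    ///
    /// # Safety
    ///
    /// Same as [`push_command_buffer`](Self::push_command_buffer), for each of the given
    /// command buffers.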
    #[inline]
    pub unsafe fn extend_command_buffers(
        &mut self,
        command_buffers: impl IntoIterator<Item = Arc<raw::CommandBuffer>>,
    ) {
        self.command_buffers.extend(command_buffers);
    }
}

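/// A map of clear values for the images cleared during a task graph execution, keyed by image
/// ID.
///
/// The entries are pre-populated by the task graph; [`set`](Self::set) only fills in values
/// that haven't been set yet and silently ignores IDs that aren't part of the map.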
pub struct ClearValues<'a> {
    inner: &'a mut LinearMap<Id, Option<ClearValue>>,
    resource_map: &'a ResourceMap<'a>,
}

impl ClearValues<'_> {
    #[inline]
    pub fn set(&mut self, id: Id<Image>, clear_value: impl Into<ClearValue>) {
        self.set_inner(id, clear_value.into());
    }

    fn set_inner(&mut self, id: Id<Image>, clear_value: ClearValue) {
        let mut id = id.erase();

        if !id.is_virtual() {
            let virtual_resources = self.resource_map.virtual_resources();

            if let Some(&virtual_id) = virtual_resources.physical_map().get(&id) {
                id = virtual_id;
            } else {
                return;
            }
        }

        if let Some(value) = self.inner.get_mut(&id) {
            if value.is_none() {
                *value = Some(clear_value);
            }
        }
    }
}

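/// The result type of a task.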
pub type TaskResult<T = (), E = TaskError> = ::std::result::Result<T, E>;

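/// Error that can happen inside a task.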
#[derive(Debug)]
pub enum TaskError {
    InvalidSlot(InvalidSlotError),
    HostAccess(HostAccessError),
    ValidationError(Box<ValidationError>),
}

impl From<InvalidSlotError> for TaskError {
    fn from(err: InvalidSlotError) -> Self {
        Self::InvalidSlot(err)
    }
}

impl From<HostAccessError> for TaskError {
    fn from(err: HostAccessError) -> Self {
        Self::HostAccess(err)
    }
}

impl From<Box<ValidationError>> for TaskError {
    fn from(err: Box<ValidationError>) -> Self {
        Self::ValidationError(err)
    }
}

impl fmt::Display for TaskError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::InvalidSlot(_) => "invalid slot",
            Self::HostAccess(_) => "a host access error occurred",
            Self::ValidationError(_) => "a validation error occurred",
        };

        f.write_str(msg)
    }
}

impl Error for TaskError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::InvalidSlot(err) => Some(err),
            Self::HostAccess(err) => Some(err),
            Self::ValidationError(err) => Some(err),
        }
    }
}

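/// Error that can happen when trying to use an invalid slot ID.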
#[derive(Debug)]
pub struct InvalidSlotError {
    id: Id,
}

impl InvalidSlotError {
    fn new<O>(id: Id<O>) -> Self {
        InvalidSlotError { id: id.erase() }
    }
}

impl fmt::Display for InvalidSlotError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let &InvalidSlotError { id } = self;
        let object_type = id.object_type();

        write!(f, "invalid slot for object type `{object_type:?}`: {id:?}")
    }
}

impl Error for InvalidSlotError {}

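/// Error that can happen when attempting to read or write a resource from the host.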
#[derive(Debug)]
pub enum HostAccessError {
    Unmanaged,
    NotHostMapped,
    OutOfMappedRange,
}

impl fmt::Display for HostAccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::Unmanaged => "the resource is not managed by vulkano",
            Self::NotHostMapped => "the device memory is not currently host-mapped",
            Self::OutOfMappedRange => {
                "the requested range is not within the currently mapped range of device memory"
            }
        };

        f.write_str(msg)
    }
}

impl Error for HostAccessError {}

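/// Specifies the type of queue family that a task node should be assigned to.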
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum QueueFamilyType {
    /// Picks a queue family that supports graphics operations.
    Graphics,

    /// Picks a queue family that supports compute operations.
    Compute,

    /// Picks a queue family that supports transfer operations.
    Transfer,

    /// Uses the queue family with the given index.
    Specific { index: u32 },
}

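/// A unique identifier for an object stored in the `Resources` collection or in a task graph.
///
/// The ID wraps a `SlotId` whose tag encodes the object type and whether the ID is virtual
/// (part of a task graph) or physical (part of the resources collection).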
#[repr(transparent)]
pub struct Id<T = ()> {
    slot: SlotId,
    marker: PhantomData<fn() -> T>,
}

impl<T> Id<T> {
    /// An ID that's guaranteed to be invalid.
    pub const INVALID: Self = Id {
        slot: SlotId::INVALID,
        marker: PhantomData,
    };

    const unsafe fn new(slot: SlotId) -> Self {
        Id {
            slot,
            marker: PhantomData,
        }
    }

    fn index(self) -> u32 {
        self.slot.index()
    }

    /// Returns `true` if this ID represents a virtual resource of a task graph.
    #[inline]
    pub const fn is_virtual(self) -> bool {
        self.slot.tag() & Id::VIRTUAL_BIT != 0
    }

    fn is_exclusive(self) -> bool {
        self.slot.tag() & Id::EXCLUSIVE_BIT != 0
    }

    fn erase(self) -> Id {
        unsafe { Id::new(self.slot) }
    }

    fn object_type(self) -> ObjectType {
        match self.slot.tag() & Id::OBJECT_TYPE_MASK {
            Buffer::TAG => ObjectType::Buffer,
            Image::TAG => ObjectType::Image,
            Swapchain::TAG => ObjectType::Swapchain,
            Flight::TAG => ObjectType::Flight,
            _ => unreachable!(),
        }
    }
}

impl Id<Swapchain> {
    /// Returns an ID representing the image of the swapchain that's currently acquired.
    #[inline]
    pub const fn current_image_id(self) -> Id<Image> {
        unsafe { Id::new(self.slot) }
    }
}

impl Id {
    const OBJECT_TYPE_MASK: u32 = 0b111;

    const VIRTUAL_BIT: u32 = 1 << 7;
    const EXCLUSIVE_BIT: u32 = 1 << 6;

    fn is<O: Object>(self) -> bool {
        self.object_type() == O::TYPE
    }

    unsafe fn parametrize<O: Object>(self) -> Id<O> {
        unsafe { Id::new(self.slot) }
    }
}

impl<T> Clone for Id<T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Id<T> {}

impl<T> fmt::Debug for Id<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if *self == Id::INVALID {
            f.pad("Id::INVALID")
        } else {
            f.debug_struct("Id")
                .field("index", &self.slot.index())
                .field("generation", &self.slot.generation())
                .finish()
        }
    }
}

impl<T> PartialEq for Id<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.slot == other.slot
    }
}

impl<T> Eq for Id<T> {}

impl<T> Hash for Id<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.slot.hash(state);
    }
}

impl<T> PartialOrd for Id<T> {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for Id<T> {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.slot.cmp(&other.slot)
    }
}

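/// A reference to an object stored in the `Resources` collection, wrapping
/// [`concurrent_slotmap::Ref`] and dereferencing to the contained object.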
pub struct Ref<'a, T>(concurrent_slotmap::Ref<'a, T>);

impl<T> Deref for Ref<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T: fmt::Debug> fmt::Debug for Ref<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

trait Object {
    const TYPE: ObjectType;

    const TAG: u32 = Self::TYPE as u32;
}

impl Object for Buffer {
    const TYPE: ObjectType = ObjectType::Buffer;
}

impl Object for Image {
    const TYPE: ObjectType = ObjectType::Image;
}

impl Object for Swapchain {
    const TYPE: ObjectType = ObjectType::Swapchain;
}

impl Object for Flight {
    const TYPE: ObjectType = ObjectType::Flight;
}

impl Object for Framebuffer {
    const TYPE: ObjectType = ObjectType::Framebuffer;
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ObjectType {
    Buffer = 0,
    Image = 1,
    Swapchain = 2,
    Flight = 3,
    Framebuffer = 4,
}

#[derive(Clone, Copy, PartialEq, Eq)]
pub struct NonExhaustive<'a>(PhantomData<&'a ()>);

impl fmt::Debug for NonExhaustive<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("NonExhaustive")
    }
}

const NE: NonExhaustive<'static> = NonExhaustive(PhantomData);

#[cfg(test)]
mod tests {
    // Creates a `Resources` collection and a list of queues for testing, returning early from
    // the enclosing test if no suitable Vulkan implementation or device is available.
    macro_rules! test_queues {
        () => {{
            let Ok(library) = vulkano::VulkanLibrary::new() else {
                return;
            };
            let Ok(instance) = vulkano::instance::Instance::new(library, Default::default()) else {
                return;
            };
            let Ok(mut physical_devices) = instance.enumerate_physical_devices() else {
                return;
            };
            let Some(physical_device) = physical_devices.find(|p| {
                p.queue_family_properties().iter().any(|q| {
                    q.queue_flags
                        .contains(vulkano::device::QueueFlags::GRAPHICS)
                })
            }) else {
                return;
            };
            let queue_create_infos = physical_device
                .queue_family_properties()
                .iter()
                .enumerate()
                .map(|(i, _)| vulkano::device::QueueCreateInfo {
                    queue_family_index: i as u32,
                    ..Default::default()
                })
                .collect();
            let Ok((device, queues)) = vulkano::device::Device::new(
                physical_device,
                vulkano::device::DeviceCreateInfo {
                    queue_create_infos,
                    ..Default::default()
                },
            ) else {
                return;
            };

            (
                $crate::resource::Resources::new(&device, &Default::default()),
                queues.collect::<Vec<_>>(),
            )
        }};
    }
    pub(crate) use test_queues;
}