1use std::ffi::c_void;
2use std::ops::Range;
3use std::sync::Mutex;
4
5use anyhow::{anyhow, bail};
6use downcast_rs::{Downcast, impl_downcast};
7use tract_core::dyn_clone;
8use tract_core::internal::*;
9use tract_core::value::TValue;
10
11use crate::tensor::{DeviceTensor, OwnedDeviceTensor};
12
13pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
14 fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
15 fn uninitialized_device_tensor(
16 &self,
17 shape: &[usize],
18 dt: DatumType,
19 ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
20 fn uninitialized_device_exotic_tensor(
21 &self,
22 exotic_fact: Box<dyn ExoticFact>,
23 ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
24 fn synchronize(&self) -> TractResult<()>;
25 #[allow(clippy::too_many_arguments)]
26 fn copy_nd(
27 &self,
28 input: &DeviceTensor,
29 input_offset: usize,
30 input_strides: &[isize],
31 output: &DeviceTensor,
32 output_offset: usize,
33 output_shape: &[usize],
34 output_strides: &[isize],
35 ) -> TractResult<()>;
36
37 fn assign_slice(
39 &self,
40 dst: &DeviceTensor,
41 dst_range: Range<usize>,
42 src: &DeviceTensor,
43 src_range: Range<usize>,
44 axis: usize,
45 ) -> TractResult<()> {
46 let mut zone_shape: TVec<usize> = src.shape().into();
47 zone_shape[axis] = src_range.len();
48 if zone_shape.iter().product::<usize>() == 0 {
49 return Ok(());
50 }
51 let src_offset =
52 src_range.start * src.strides()[axis] as usize * src.datum_type().size_of();
53 let dst_offset =
54 dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of();
55 self.copy_nd(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides())
56 }
57
58 #[allow(clippy::too_many_arguments)]
60 fn copy_with_origins(
61 &self,
62 zone_shape: &[usize],
63 dst: &DeviceTensor,
64 dst_origin: &[usize],
65 dst_strides: &[isize],
66 src: &DeviceTensor,
67 src_origin: &[usize],
68 src_strides: &[isize],
69 ) -> TractResult<()> {
70 if zone_shape.iter().product::<usize>() == 0 {
71 return Ok(());
72 }
73 let dt_size = src.datum_type().size_of();
74 let src_offset: usize =
75 src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
76 * dt_size;
77 let dst_offset: usize =
78 dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
79 * dt_size;
80 self.copy_nd(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides)
81 }
82
83 fn flat_copy(
85 &self,
86 src: &DeviceTensor,
87 src_byte_offset: usize,
88 dst: &DeviceTensor,
89 dst_byte_offset: usize,
90 byte_len: usize,
91 ) -> TractResult<()> {
92 if byte_len == 0 {
93 return Ok(());
94 }
95 let elem_size = src.datum_type().size_of();
98 ensure!(
99 byte_len.is_multiple_of(elem_size)
100 && src_byte_offset.is_multiple_of(elem_size)
101 && dst_byte_offset.is_multiple_of(elem_size),
102 "flat_copy: byte_len {byte_len}, src_offset {src_byte_offset}, \
103 dst_offset {dst_byte_offset} not aligned to element size {elem_size}"
104 );
105 self.copy_nd(
106 src,
107 src_byte_offset,
108 &[1],
109 dst,
110 dst_byte_offset,
111 &[byte_len / elem_size],
112 &[1],
113 )
114 }
115}
116
// Lets a `dyn DeviceContext` be downcast to the concrete backend context type.
impl_downcast!(DeviceContext);
// Implements `Clone` for `Box<dyn DeviceContext>`; `get_context` relies on it to hand out copies.
dyn_clone::clone_trait_object!(DeviceContext);
119
120pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
121 fn ptr(&self) -> *const c_void;
122}
123
// Lets a `dyn DeviceBuffer` be downcast to a backend's concrete buffer type.
impl_downcast!(DeviceBuffer);
// Implements `Clone` for `Box<dyn DeviceBuffer>`.
dyn_clone::clone_trait_object!(DeviceBuffer);

/// Process-wide device context slot.
///
/// `None` until a context is installed; `set_context` only fills it while it is still `None`,
/// and `get_context` hands out clones of the installed context.
pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
128
129pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
130 let mut context = DEVICE_CONTEXT.lock().unwrap();
131 if context.is_none() {
132 *context = Some(curr_context);
133 Ok(())
134 } else {
135 bail!("Context is already set")
136 }
137}
138
139pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
140 let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
141 guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
142}