Skip to main content

tract_gpu/
device.rs

1use std::ffi::c_void;
2use std::ops::Range;
3use std::sync::Mutex;
4
5use anyhow::{anyhow, bail};
6use downcast_rs::{Downcast, impl_downcast};
7use tract_core::dyn_clone;
8use tract_core::internal::*;
9use tract_core::value::TValue;
10
11use crate::tensor::{DeviceTensor, OwnedDeviceTensor};
12
13pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
14    fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
15    fn uninitialized_device_tensor(
16        &self,
17        shape: &[usize],
18        dt: DatumType,
19    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
20    fn uninitialized_device_exotic_tensor(
21        &self,
22        exotic_fact: Box<dyn ExoticFact>,
23    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
24    fn synchronize(&self) -> TractResult<()>;
25    #[allow(clippy::too_many_arguments)]
26    fn copy_nd(
27        &self,
28        input: &DeviceTensor,
29        input_offset: usize,
30        input_strides: &[isize],
31        output: &DeviceTensor,
32        output_offset: usize,
33        output_shape: &[usize],
34        output_strides: &[isize],
35    ) -> TractResult<()>;
36
37    /// Copy a slice along `axis` from `src[src_range]` into `dst[dst_range]`.
38    fn assign_slice(
39        &self,
40        dst: &DeviceTensor,
41        dst_range: Range<usize>,
42        src: &DeviceTensor,
43        src_range: Range<usize>,
44        axis: usize,
45    ) -> TractResult<()> {
46        let mut zone_shape: TVec<usize> = src.shape().into();
47        zone_shape[axis] = src_range.len();
48        if zone_shape.iter().product::<usize>() == 0 {
49            return Ok(());
50        }
51        let src_offset =
52            src_range.start * src.strides()[axis] as usize * src.datum_type().size_of();
53        let dst_offset =
54            dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of();
55        self.copy_nd(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides())
56    }
57
58    /// Copy from `src` into `dst` with given origins and strides.
59    #[allow(clippy::too_many_arguments)]
60    fn copy_with_origins(
61        &self,
62        zone_shape: &[usize],
63        dst: &DeviceTensor,
64        dst_origin: &[usize],
65        dst_strides: &[isize],
66        src: &DeviceTensor,
67        src_origin: &[usize],
68        src_strides: &[isize],
69    ) -> TractResult<()> {
70        if zone_shape.iter().product::<usize>() == 0 {
71            return Ok(());
72        }
73        let dt_size = src.datum_type().size_of();
74        let src_offset: usize =
75            src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
76                * dt_size;
77        let dst_offset: usize =
78            dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
79                * dt_size;
80        self.copy_nd(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides)
81    }
82
83    /// Flat memcpy of `byte_len` bytes.
84    fn flat_copy(
85        &self,
86        src: &DeviceTensor,
87        src_byte_offset: usize,
88        dst: &DeviceTensor,
89        dst_byte_offset: usize,
90        byte_len: usize,
91    ) -> TractResult<()> {
92        if byte_len == 0 {
93            return Ok(());
94        }
95        // copy_nd dispatches a typed kernel (u8/u16/u32/u64 based on datum_type),
96        // so shape and strides are in elements, not bytes.
97        let elem_size = src.datum_type().size_of();
98        ensure!(
99            byte_len.is_multiple_of(elem_size)
100                && src_byte_offset.is_multiple_of(elem_size)
101                && dst_byte_offset.is_multiple_of(elem_size),
102            "flat_copy: byte_len {byte_len}, src_offset {src_byte_offset}, \
103             dst_offset {dst_byte_offset} not aligned to element size {elem_size}"
104        );
105        self.copy_nd(
106            src,
107            src_byte_offset,
108            &[1],
109            dst,
110            dst_byte_offset,
111            &[byte_len / elem_size],
112            &[1],
113        )
114    }
115}
116
117impl_downcast!(DeviceContext);
118dyn_clone::clone_trait_object!(DeviceContext);
119
120pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
121    fn ptr(&self) -> *const c_void;
122}
123
124impl_downcast!(DeviceBuffer);
125dyn_clone::clone_trait_object!(DeviceBuffer);
126
127pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
128
129pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
130    let mut context = DEVICE_CONTEXT.lock().unwrap();
131    if context.is_none() {
132        *context = Some(curr_context);
133        Ok(())
134    } else {
135        bail!("Context is already set")
136    }
137}
138
139pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
140    let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
141    guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
142}