//! Device abstraction layer for `tract_gpu` (tract_gpu/device.rs):
//! the `DeviceContext` / `DeviceBuffer` traits and the global context singleton.

1use std::ffi::c_void;
2use std::ops::Range;
3use std::sync::Mutex;
4
5use anyhow::{anyhow, bail};
6use downcast_rs::{Downcast, impl_downcast};
7use tract_core::dyn_clone;
8use tract_core::internal::*;
9use tract_core::value::TValue;
10
11use crate::tensor::{DeviceTensor, OwnedDeviceTensor};
12
13pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
14    fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
15    fn uninitialized_device_tensor(
16        &self,
17        shape: &[usize],
18        dt: DatumType,
19    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
20    fn uninitialized_device_exotic_tensor(
21        &self,
22        exotic_fact: Box<dyn ExoticFact>,
23    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
24    fn synchronize(&self) -> TractResult<()>;
25    fn copy_nd(
26        &self,
27        input: &DeviceTensor,
28        input_offset: usize,
29        input_strides: &[isize],
30        output: &DeviceTensor,
31        output_offset: usize,
32        output_shape: &[usize],
33        output_strides: &[isize],
34    ) -> TractResult<()>;
35
36    /// Copy a slice along `axis` from `src[src_range]` into `dst[dst_range]`.
37    fn assign_slice(
38        &self,
39        dst: &DeviceTensor,
40        dst_range: Range<usize>,
41        src: &DeviceTensor,
42        src_range: Range<usize>,
43        axis: usize,
44    ) -> TractResult<()> {
45        let mut zone_shape: TVec<usize> = src.shape().into();
46        zone_shape[axis] = src_range.len();
47        if zone_shape.iter().product::<usize>() == 0 {
48            return Ok(());
49        }
50        let src_offset =
51            src_range.start * src.strides()[axis] as usize * src.datum_type().size_of();
52        let dst_offset =
53            dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of();
54        self.copy_nd(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides())
55    }
56
57    /// Copy from `src` into `dst` with given origins and strides.
58    fn copy_with_origins(
59        &self,
60        zone_shape: &[usize],
61        dst: &DeviceTensor,
62        dst_origin: &[usize],
63        dst_strides: &[isize],
64        src: &DeviceTensor,
65        src_origin: &[usize],
66        src_strides: &[isize],
67    ) -> TractResult<()> {
68        if zone_shape.iter().product::<usize>() == 0 {
69            return Ok(());
70        }
71        let dt_size = src.datum_type().size_of();
72        let src_offset: usize =
73            src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
74                * dt_size;
75        let dst_offset: usize =
76            dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
77                * dt_size;
78        self.copy_nd(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides)
79    }
80
81    /// Flat memcpy of `byte_len` bytes.
82    fn flat_copy(
83        &self,
84        src: &DeviceTensor,
85        src_byte_offset: usize,
86        dst: &DeviceTensor,
87        dst_byte_offset: usize,
88        byte_len: usize,
89    ) -> TractResult<()> {
90        if byte_len == 0 {
91            return Ok(());
92        }
93        // copy_nd dispatches a typed kernel (u8/u16/u32/u64 based on datum_type),
94        // so shape and strides are in elements, not bytes.
95        let elem_size = src.datum_type().size_of();
96        ensure!(
97            byte_len % elem_size == 0
98                && src_byte_offset % elem_size == 0
99                && dst_byte_offset % elem_size == 0,
100            "flat_copy: byte_len {byte_len}, src_offset {src_byte_offset}, \
101             dst_offset {dst_byte_offset} not aligned to element size {elem_size}"
102        );
103        self.copy_nd(
104            src,
105            src_byte_offset,
106            &[1],
107            dst,
108            dst_byte_offset,
109            &[byte_len / elem_size],
110            &[1],
111        )
112    }
113}
114
// Enable downcasting `&dyn DeviceContext` to a concrete backend type, and
// cloning of `Box<dyn DeviceContext>` (required by `get_context`).
impl_downcast!(DeviceContext);
dyn_clone::clone_trait_object!(DeviceContext);
117
/// Opaque handle to a device memory allocation.
pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
    /// Raw pointer to the underlying device allocation
    /// (backend-specific meaning; may not be host-dereferenceable).
    fn ptr(&self) -> *const c_void;
}

// Enable downcasting and boxed cloning for `dyn DeviceBuffer`.
impl_downcast!(DeviceBuffer);
dyn_clone::clone_trait_object!(DeviceBuffer);
124
125pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
126
127pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
128    let mut context = DEVICE_CONTEXT.lock().unwrap();
129    if context.is_none() {
130        *context = Some(curr_context);
131        Ok(())
132    } else {
133        bail!("Context is already set")
134    }
135}
136
137pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
138    let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
139    guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
140}