Skip to main content

tract_gpu/
device.rs

1use std::ffi::c_void;
2use std::sync::Mutex;
3
4use anyhow::{anyhow, bail};
5use downcast_rs::{Downcast, impl_downcast};
6use tract_core::dyn_clone;
7use tract_core::internal::OpaqueFact;
8use tract_core::prelude::{DatumType, TractResult};
9use tract_core::value::TValue;
10
11use crate::tensor::OwnedDeviceTensor;
12
13pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
14    fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
15    fn uninitialized_device_tensor(
16        &self,
17        shape: &[usize],
18        dt: DatumType,
19    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
20    fn uninitialized_device_opaque_tensor(
21        &self,
22        opaque_fact: Box<dyn OpaqueFact>,
23    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
24    fn synchronize(&self) -> TractResult<()>;
25}
26
27impl_downcast!(DeviceContext);
28dyn_clone::clone_trait_object!(DeviceContext);
29
30pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
31    fn ptr(&self) -> *const c_void;
32}
33
34impl_downcast!(DeviceBuffer);
35dyn_clone::clone_trait_object!(DeviceBuffer);
36
37pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
38
39pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
40    let mut context = DEVICE_CONTEXT.lock().unwrap();
41    if context.is_none() {
42        *context = Some(curr_context);
43        Ok(())
44    } else {
45        bail!("Context is already set")
46    }
47}
48
49pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
50    let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
51    guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
52}