Skip to main content

tract_gpu/
device.rs

1use core::fmt;
2use std::ffi::c_void;
3use std::sync::Mutex;
4
5use anyhow::{anyhow, bail};
6use downcast_rs::{impl_downcast, Downcast};
7use tract_core::dyn_clone;
8use tract_core::prelude::TractResult;
9
10pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
11    fn buffer_from_slice(&self, tensor: &[u8]) -> Box<dyn DeviceBuffer>;
12    fn synchronize(&self) -> TractResult<()>;
13}
14
15impl_downcast!(DeviceContext);
16dyn_clone::clone_trait_object!(DeviceContext);
17
18pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync {
19    fn info(&self) -> String;
20    fn ptr(&self) -> *const c_void;
21}
22
23impl_downcast!(DeviceBuffer);
24dyn_clone::clone_trait_object!(DeviceBuffer);
25
26impl fmt::Debug for dyn DeviceBuffer {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(f, "DeviceBuffer: {:?}", self.info())
29    }
30}
31
32pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
33
34pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
35    let mut context = DEVICE_CONTEXT.lock().unwrap();
36    if context.is_none() {
37        *context = Some(curr_context);
38        Ok(())
39    } else {
40        bail!("Context is already set")
41    }
42}
43
44pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
45    let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
46    guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
47}