tritonserver_rs/
context.rs

1use std::{collections::HashMap, ffi::c_int, sync::Arc};
2
3use cuda_driver_sys::{
4    cuCtxCreate_v2, cuCtxDestroy_v2, cuCtxGetApiVersion, cuCtxPopCurrent_v2, cuCtxPushCurrent_v2,
5    cuDeviceGet, cuDeviceGetAttribute, cuDeviceGetName, cuDeviceTotalMem_v2, cuInit, CUcontext,
6    CUdevice, CUdevice_attribute,
7};
8use parking_lot::{Once, RwLock};
9
10use crate::{error::Error, from_char_array};
11
12/// Initialize Cuda runtime. Should be called before any Cuda function, perfectly — on the start of the application.
13pub fn init_cuda() -> Result<(), Error> {
14    cuda_call!(cuInit(0))
15}
16
17lazy_static::lazy_static! {
18    static ref CUDA_CONTEXTS: RwLock<HashMap<i32, Arc<Context>>> = RwLock::new(HashMap::default());
19    static ref ONCE: Once = Once::new();
20}
21
22/// Get Cuda context on device.
23pub fn get_context(device: i32) -> Result<Arc<Context>, Error> {
24    if let Some(ctx) = CUDA_CONTEXTS.read().get(&device) {
25        return Ok(ctx.clone());
26    }
27
28    ONCE.call_once(|| init_cuda().unwrap());
29
30    let dev = CuDevice::new(device)?;
31    log::info!(
32        "Using: {} {:.2}Gb",
33        dev.get_name().unwrap(),
34        dev.get_total_mem().unwrap() as f64 / (1_000_000_000) as f64
35    );
36
37    let arc = Arc::new(Context::new(dev, 0)?);
38    CUDA_CONTEXTS.write().insert(device, arc.clone());
39
40    Ok(arc)
41}
42
43/// Handler of Cuda context that was pushed as current.
44/// On Drop will pop context from current.
45pub struct ContextHandler<'a> {
46    _ctx: &'a Context,
47}
48
49impl Drop for ContextHandler<'_> {
50    fn drop(&mut self) {
51        let _ = cuda_call!(cuCtxPopCurrent_v2(std::ptr::null_mut()));
52    }
53}
54
55/// Cuda Context.
56pub struct Context {
57    context: cuda_driver_sys::CUcontext,
58}
59
60unsafe impl Send for Context {}
61unsafe impl Sync for Context {}
62
63impl Context {
64    /// Create Context on device `dev`. It is recommended to use zeroed `flags`.
65    pub fn new(dev: CuDevice, flags: u32) -> Result<Context, Error> {
66        let mut ctx = Context {
67            context: std::ptr::null_mut(),
68        };
69
70        cuda_call!(cuCtxCreate_v2(
71            &mut ctx.context as *mut CUcontext,
72            flags,
73            dev.device
74        ))
75        .map(|_| ctx)
76    }
77
78    /// Get Cuda API version.
79    pub fn get_api_version(&self) -> Result<u32, Error> {
80        let mut ver = 0;
81        cuda_call!(cuCtxGetApiVersion(self.context, &mut ver as *mut u32)).map(|_| ver)
82    }
83
84    /// Make this context current.
85    pub fn make_current(&self) -> Result<ContextHandler<'_>, Error> {
86        cuda_call!(cuCtxPushCurrent_v2(self.context))?;
87
88        Ok(ContextHandler { _ctx: self })
89    }
90}
91
92impl Drop for Context {
93    fn drop(&mut self) {
94        if !self.context.is_null() {
95            let _ = cuda_call!(cuCtxDestroy_v2(self.context));
96        }
97    }
98}
99
100/// Cuda representation of the device.
101#[derive(Debug, Clone, Copy, Default)]
102pub struct CuDevice {
103    pub device: CUdevice,
104}
105
106impl CuDevice {
107    /// Create new device with id `ordinal`.
108    pub fn new(ordinal: c_int) -> Result<CuDevice, Error> {
109        let mut d = CuDevice { device: 0 };
110
111        cuda_call!(cuDeviceGet(&mut d.device as *mut i32, ordinal)).map(|_| d)
112    }
113
114    /// Get attributes of the device.
115    pub fn get_attribute(&self, attr: CUdevice_attribute) -> Result<c_int, Error> {
116        let mut pi = 0;
117
118        cuda_call!(cuDeviceGetAttribute(&mut pi as *mut i32, attr, self.device)).map(|_| pi)
119    }
120
121    /// Get name of the device.
122    pub fn get_name(&self) -> Result<String, Error> {
123        let mut name = vec![0; 256];
124
125        cuda_call!(
126            cuDeviceGetName(name.as_mut_ptr() as *mut _, 256, self.device,),
127            from_char_array(name.as_mut_ptr())
128        )
129    }
130
131    /// Get total mem of the device.
132    pub fn get_total_mem(&self) -> Result<usize, Error> {
133        let mut val = 0;
134
135        cuda_call!(cuDeviceTotalMem_v2(
136            &mut val as *mut usize as *mut _,
137            self.device
138        ))
139        .map(|_| val)
140    }
141}