tritonserver_rs/
context.rs1use std::{collections::HashMap, ffi::c_int, sync::Arc};
2
3use cuda_driver_sys::{
4 cuCtxCreate_v2, cuCtxDestroy_v2, cuCtxGetApiVersion, cuCtxPopCurrent_v2, cuCtxPushCurrent_v2,
5 cuDeviceGet, cuDeviceGetAttribute, cuDeviceGetName, cuDeviceTotalMem_v2, cuInit, CUcontext,
6 CUdevice, CUdevice_attribute,
7};
8use parking_lot::{Once, RwLock};
9
10use crate::{error::Error, from_char_array};
11
12pub fn init_cuda() -> Result<(), Error> {
14 cuda_call!(cuInit(0))
15}
16
17lazy_static::lazy_static! {
18 static ref CUDA_CONTEXTS: RwLock<HashMap<i32, Arc<Context>>> = RwLock::new(HashMap::default());
19 static ref ONCE: Once = Once::new();
20}
21
22pub fn get_context(device: i32) -> Result<Arc<Context>, Error> {
24 if let Some(ctx) = CUDA_CONTEXTS.read().get(&device) {
25 return Ok(ctx.clone());
26 }
27
28 ONCE.call_once(|| init_cuda().unwrap());
29
30 let dev = CuDevice::new(device)?;
31 log::info!(
32 "Using: {} {:.2}Gb",
33 dev.get_name().unwrap(),
34 dev.get_total_mem().unwrap() as f64 / (1_000_000_000) as f64
35 );
36
37 let arc = Arc::new(Context::new(dev, 0)?);
38 CUDA_CONTEXTS.write().insert(device, arc.clone());
39
40 Ok(arc)
41}
42
43pub struct ContextHandler<'a> {
46 _ctx: &'a Context,
47}
48
49impl Drop for ContextHandler<'_> {
50 fn drop(&mut self) {
51 let _ = cuda_call!(cuCtxPopCurrent_v2(std::ptr::null_mut()));
52 }
53}
54
55pub struct Context {
57 context: cuda_driver_sys::CUcontext,
58}
59
60unsafe impl Send for Context {}
61unsafe impl Sync for Context {}
62
63impl Context {
64 pub fn new(dev: CuDevice, flags: u32) -> Result<Context, Error> {
66 let mut ctx = Context {
67 context: std::ptr::null_mut(),
68 };
69
70 cuda_call!(cuCtxCreate_v2(
71 &mut ctx.context as *mut CUcontext,
72 flags,
73 dev.device
74 ))
75 .map(|_| ctx)
76 }
77
78 pub fn get_api_version(&self) -> Result<u32, Error> {
80 let mut ver = 0;
81 cuda_call!(cuCtxGetApiVersion(self.context, &mut ver as *mut u32)).map(|_| ver)
82 }
83
84 pub fn make_current(&self) -> Result<ContextHandler<'_>, Error> {
86 cuda_call!(cuCtxPushCurrent_v2(self.context))?;
87
88 Ok(ContextHandler { _ctx: self })
89 }
90}
91
92impl Drop for Context {
93 fn drop(&mut self) {
94 if !self.context.is_null() {
95 let _ = cuda_call!(cuCtxDestroy_v2(self.context));
96 }
97 }
98}
99
100#[derive(Debug, Clone, Copy, Default)]
102pub struct CuDevice {
103 pub device: CUdevice,
104}
105
106impl CuDevice {
107 pub fn new(ordinal: c_int) -> Result<CuDevice, Error> {
109 let mut d = CuDevice { device: 0 };
110
111 cuda_call!(cuDeviceGet(&mut d.device as *mut i32, ordinal)).map(|_| d)
112 }
113
114 pub fn get_attribute(&self, attr: CUdevice_attribute) -> Result<c_int, Error> {
116 let mut pi = 0;
117
118 cuda_call!(cuDeviceGetAttribute(&mut pi as *mut i32, attr, self.device)).map(|_| pi)
119 }
120
121 pub fn get_name(&self) -> Result<String, Error> {
123 let mut name = vec![0; 256];
124
125 cuda_call!(
126 cuDeviceGetName(name.as_mut_ptr() as *mut _, 256, self.device,),
127 from_char_array(name.as_mut_ptr())
128 )
129 }
130
131 pub fn get_total_mem(&self) -> Result<usize, Error> {
133 let mut val = 0;
134
135 cuda_call!(cuDeviceTotalMem_v2(
136 &mut val as *mut usize as *mut _,
137 self.device
138 ))
139 .map(|_| val)
140 }
141}