//! `tritonserver_rs/macros.rs` — helper macros and context utilities for CUDA and Triton FFI calls.

1use crate::Error;
2
#[cfg(feature = "gpu")]
/// Run a CUDA call and get a `Result<_, tritonserver_rs::Error>` instead of the raw
/// `cuda_driver_sys::CUresult`.
///
/// `cuda_call!(expr)` yields `Ok(())` on `CUDA_SUCCESS`; `cuda_call!(expr, val)` yields
/// `Ok(val)`. Any other `CUresult` is turned into an `ErrorCode::Internal` error whose
/// message contains the debug-formatted result code.
macro_rules! cuda_call {
    // Unit form: delegate to the two-argument arm with `()` as the success value,
    // so the success/error logic lives in exactly one place.
    ($expr: expr) => {
        cuda_call!($expr, ())
    };
    ($expr: expr, $val: expr) => {{
        #[allow(clippy::macro_metavars_in_unsafe)]
        let res = unsafe { $expr };

        if res != cuda_driver_sys::CUresult::CUDA_SUCCESS {
            Err($crate::error::Error::new(
                $crate::error::ErrorCode::Internal,
                format!("Cuda result: {:?}", res),
            ))
        } else {
            std::result::Result::<_, $crate::error::Error>::Ok($val)
        }
    }};
}
33
/// Run a Triton C-API method and get a `Result<_, tritonserver_rs::Error>` instead of the
/// raw error pointer the call returns (null means success).
///
/// `triton_call!(expr)` yields `Ok(())` on success; `triton_call!(expr, val)` yields
/// `Ok(val)`. A non-null result is converted into an `Error` via `From`/`Into`.
macro_rules! triton_call {
    // Unit form: delegate to the two-argument arm with `()` as the success value,
    // so the null-check logic lives in exactly one place.
    ($expr: expr) => {
        triton_call!($expr, ())
    };
    ($expr: expr, $val: expr) => {{
        #[allow(clippy::macro_metavars_in_unsafe)]
        let res = unsafe { $expr };

        if res.is_null() {
            std::result::Result::<_, $crate::error::Error>::Ok($val)
        } else {
            std::result::Result::<_, $crate::error::Error>::Err(res.into())
        }
    }};
}
57
// The next two functions live in this module for historical reasons.
59
60/// Run cuda code (which should be run in sync + cuda context pinned) in asynchronous context.
61///
62/// First argument is an id of device to run function on; second is the code to run.
63///
64/// If "gpu" feature is off just runs a code without context/blocking.
65pub async fn run_in_context<T, F>(device: i32, code: F) -> Result<T, Error>
66where
67    T: Send + 'static,
68    F: FnOnce() -> T + Send + 'static,
69{
70    #[cfg(feature = "gpu")]
71    {
72        tokio::task::spawn_blocking(move || {
73            let ctx = crate::get_context(device)?;
74            let _handle = ctx.make_current()?;
75            Ok(code())
76        })
77        .await
78        .map_err(|_| {
79            Error::new(
80                crate::ErrorCode::Internal,
81                "tokio failed to join thread on run_in_context",
82            )
83        })?
84    }
85    #[cfg(not(feature = "gpu"))]
86    {
87        let _ = device;
88        Ok(code())
89    }
90}
91
92/// Run cuda code (which should be run in sync + cuda context pinned).
93///
94/// First argument is an id of device to run function on; second is the code to run.
95///
96/// If "gpu" feature is off just runs a code without context/blocking.
97pub fn run_in_context_sync<T, F: FnOnce() -> T>(device: i32, code: F) -> Result<T, Error> {
98    #[cfg(feature = "gpu")]
99    {
100        let ctx = crate::get_context(device)?;
101        let _handle = ctx.make_current()?;
102        Ok(code())
103    }
104    #[cfg(not(feature = "gpu"))]
105    {
106        let _ = device;
107        Ok(code())
108    }
109}