singe_cusolver/context.rs
1#[allow(unused_imports)]
2use crate::{dense::*, svd::*};
3
4use std::{path::Path, ptr, sync::Arc};
5
6use singe_core::path_to_cstring;
7use singe_cuda::{context::Context as CudaContext, stream::Stream, types::EmulationStrategy};
8use singe_cuda_sys::runtime;
9
10use crate::{
11 error::{Error, Result},
12 sys, try_ffi,
13 types::{DeterministicMode, MathMode},
14};
15
/// A stateful cuSOLVER handle.
///
/// Use one context per host thread or concurrent task. The handle is movable
/// between threads, but it is intentionally not `Clone` or `Sync`.
///
/// Every method that passes the handle to cuSOLVER binds the owning CUDA
/// context first, and dropping a `Context` destroys the underlying cuSOLVER
/// dense handle.
#[derive(Debug)]
pub struct Context {
    /// The owned raw handle together with the CUDA context it was created in.
    handle: Handle,
}
24
/// Owning wrapper around a raw cuSOLVER dense handle.
///
/// Holding an `Arc` to the CUDA context keeps that `CudaContext` value alive
/// for as long as the handle exists, so `Drop` can bind it before destroying
/// the handle.
#[derive(Debug)]
struct Handle {
    /// Handle returned by `cusolverDnCreate`; `Context::create` rejects null.
    raw: sys::cusolverDnHandle_t,
    /// CUDA context that was bound when the handle was created.
    cuda_ctx: Arc<CudaContext>,
}
30
// SAFETY: `Handle` is the sole owner of its cuSOLVER handle (it is neither
// `Clone` nor `Copy`), and every wrapper call that passes the handle to
// cuSOLVER binds `cuda_ctx` on the calling thread first, so moving the owner
// to another thread is sound. Only `Send` is implemented: `Handle` remains
// `!Sync` because of its raw-pointer field, so `&Context` can never be shared
// across threads and handle state is only ever mutated from one thread at a
// time, even though the setters on `Context` take `&self`.
// NOTE(review): this impl also sends the `Arc<CudaContext>`; it assumes
// `CudaContext` is safe to share between threads — confirm in singe_cuda.
unsafe impl Send for Handle {}
34
/// The stream bound to a cuSOLVER handle, as reported by [`Context::stream`].
#[derive(Debug, Clone)]
pub enum StreamBinding {
    /// The CUDA default stream (cuSOLVER reported a null stream).
    Default,
    /// A borrowed stream associated with the same CUDA context.
    Borrowed(BorrowedStream),
}
43
/// A stream borrowed from a CUDA context, associated with a cuSOLVER handle.
///
/// This value stores only the raw stream handle and the CUDA context of the
/// cuSOLVER handle that reported it. It does not own the stream and does not
/// keep it alive, so the raw handle is only meaningful while the underlying
/// stream still exists.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    /// Raw stream handle as reported by `cusolverDnGetStream`.
    handle: runtime::cudaStream_t,
    /// CUDA context of the cuSOLVER handle that reported this stream.
    cuda_ctx: Arc<CudaContext>,
}
50
51impl BorrowedStream {
52 /// Returns the raw CUDA stream handle.
53 pub const fn as_raw(&self) -> runtime::cudaStream_t {
54 self.handle
55 }
56
57 /// Returns a reference to the CUDA context this stream belongs to.
58 pub fn context(&self) -> &CudaContext {
59 self.cuda_ctx.as_ref()
60 }
61}
62
63impl Context {
64 /// Creates a cuSOLVER dense handle for the given CUDA context.
65 /// Call this before invoking other cuSOLVER operations through this wrapper.
66 ///
67 /// cuSOLVER allocates the GPU-side resources it needs here. On the first
68 /// application-defined stream passed to [`Context::set_stream`], cuSOLVER may also
69 /// allocate an internal workspace.
70 ///
71 /// # Errors
72 ///
73 /// Returns an error if the CUDA context cannot be bound, if cuSOLVER cannot
74 /// create a handle, or if cuSOLVER returns a null handle.
75 pub fn create(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
76 cuda_ctx.bind()?;
77
78 let mut handle = ptr::null_mut();
79 unsafe {
80 try_ffi!(sys::cusolverDnCreate(&raw mut handle))?;
81 }
82
83 if handle.is_null() {
84 return Err(Error::NullHandle);
85 }
86
87 Ok(Self {
88 handle: Handle {
89 raw: handle,
90 cuda_ctx: Arc::clone(cuda_ctx),
91 },
92 })
93 }
94
95 /// Returns the underlying CUDA context used by this cuSOLVER handle.
96 pub fn cuda_context(&self) -> &Arc<CudaContext> {
97 &self.handle.cuda_ctx
98 }
99
100 /// Binds the underlying CUDA context associated with this handle.
101 ///
102 /// # Errors
103 ///
104 /// Returns an error if the CUDA context cannot be bound.
105 pub fn bind(&self) -> Result<()> {
106 Ok(self.cuda_context().bind()?)
107 }
108
109 /// Ensures `stream` belongs to the same CUDA context as this handle.
110 ///
111 /// Returns an error if the stream belongs to a different context.
112 pub fn ensure_stream(&self, stream: &Stream) -> Result<()> {
113 if self.cuda_context().as_ref() != stream.context() {
114 return Err(Error::StreamContextMismatch);
115 }
116
117 self.bind()
118 }
119
120 /// Returns the stream currently used by this cuSOLVER handle.
121 ///
122 /// # Errors
123 ///
124 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
125 /// cannot report the current stream.
126 pub fn stream(&self) -> Result<StreamBinding> {
127 self.bind()?;
128
129 let mut stream = ptr::null_mut();
130 unsafe {
131 try_ffi!(sys::cusolverDnGetStream(self.as_raw(), &raw mut stream))?;
132 }
133
134 Ok(if stream.is_null() {
135 StreamBinding::Default
136 } else {
137 StreamBinding::Borrowed(BorrowedStream {
138 handle: stream,
139 cuda_ctx: Arc::clone(self.cuda_context()),
140 })
141 })
142 }
143
144 /// Sets the stream used by this cuSOLVER handle.
145 ///
146 /// Passing `None` restores the CUDA default stream.
147 ///
148 /// # Errors
149 ///
150 /// Returns an error if `stream` belongs to another CUDA context, if the CUDA
151 /// context cannot be bound, or if cuSOLVER rejects the stream.
152 pub fn set_stream(&self, stream: Option<&Stream>) -> Result<()> {
153 if let Some(stream) = stream {
154 self.ensure_stream(stream)?;
155 } else {
156 self.bind()?;
157 }
158
159 unsafe {
160 try_ffi!(sys::cusolverDnSetStream(
161 self.as_raw(),
162 match stream {
163 Some(stream) => stream.as_raw(),
164 None => ptr::null_mut(),
165 },
166 ))?;
167 }
168 Ok(())
169 }
170
171 /// Returns the deterministic mode currently configured on this handle.
172 ///
173 /// # Errors
174 ///
175 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
176 /// cannot report the deterministic mode.
177 pub fn deterministic_mode(&self) -> Result<DeterministicMode> {
178 self.bind()?;
179
180 let mut mode = sys::cusolverDeterministicMode_t::CUSOLVER_DETERMINISTIC_RESULTS;
181 unsafe {
182 try_ffi!(sys::cusolverDnGetDeterministicMode(
183 self.as_raw(),
184 &raw mut mode,
185 ))?;
186 }
187 Ok(mode.into())
188 }
189
190 /// Sets the deterministic mode for operations executed through this handle.
191 ///
192 /// Allowing non-deterministic results may improve performance for some
193 /// operations, including [`xgeqrf`], [`xgesvd`], [`xgesvdr`], and [`xgesvdp`].
194 ///
195 /// # Errors
196 ///
197 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
198 /// rejects the deterministic mode.
199 pub fn set_deterministic_mode(&self, mode: DeterministicMode) -> Result<()> {
200 self.bind()?;
201 unsafe {
202 try_ffi!(sys::cusolverDnSetDeterministicMode(
203 self.as_raw(),
204 mode.into(),
205 ))?;
206 }
207 Ok(())
208 }
209
210 /// Returns the math mode currently configured on this handle.
211 ///
212 /// See [`MathMode`] for the supported wrapper values.
213 ///
214 /// # Errors
215 ///
216 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
217 /// cannot report the math mode.
218 pub fn math_mode(&self) -> Result<MathMode> {
219 self.bind()?;
220
221 let mut mode = sys::cusolverMathMode_t::CUSOLVER_DEFAULT_MATH;
222 unsafe {
223 try_ffi!(sys::cusolverDnGetMathMode(self.as_raw(), &raw mut mode))?;
224 }
225 Ok(mode.into())
226 }
227
228 /// Sets the math mode for operations executed through this handle.
229 ///
230 /// See [`MathMode`] for the supported wrapper values and combinations.
231 ///
232 /// # Errors
233 ///
234 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
235 /// rejects the math mode.
236 pub fn set_math_mode(&self, mode: MathMode) -> Result<()> {
237 self.bind()?;
238 unsafe {
239 try_ffi!(sys::cusolverDnSetMathMode(self.as_raw(), mode.into()))?;
240 }
241 Ok(())
242 }
243
244 /// Returns the emulation strategy configured on this handle.
245 ///
246 /// This only affects operations that use one of the emulated math modes
247 /// described by [`MathMode`].
248 ///
249 /// # Errors
250 ///
251 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
252 /// cannot report the emulation strategy.
253 pub fn emulation_strategy(&self) -> Result<EmulationStrategy> {
254 self.bind()?;
255
256 let mut strategy = EmulationStrategy::Default.into();
257 unsafe {
258 try_ffi!(sys::cusolverDnGetEmulationStrategy(
259 self.as_raw(),
260 &raw mut strategy,
261 ))?;
262 }
263 Ok(strategy.into())
264 }
265
266 /// Sets the emulation strategy for operations executed through this handle.
267 ///
268 /// This only affects operations that use one of the emulated math modes
269 /// described by [`MathMode`].
270 ///
271 /// # Errors
272 ///
273 /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
274 /// rejects the emulation strategy.
275 pub fn set_emulation_strategy(&self, strategy: EmulationStrategy) -> Result<()> {
276 self.bind()?;
277 unsafe {
278 try_ffi!(sys::cusolverDnSetEmulationStrategy(
279 self.as_raw(),
280 strategy.into(),
281 ))?;
282 }
283 Ok(())
284 }
285
286 /// Installs the cuSOLVER logger callback.
287 ///
288 /// # Safety
289 ///
290 /// `callback`, if present, must remain valid for use by cuSOLVER and must
291 /// follow the callback ABI expected by the library.
292 ///
293 /// # Errors
294 ///
295 /// Returns an error if cuSOLVER rejects the callback.
296 pub unsafe fn set_logger_callback(callback: sys::cusolverDnLoggerCallback_t) -> Result<()> {
297 unsafe {
298 try_ffi!(sys::cusolverDnLoggerSetCallback(callback))?;
299 }
300 Ok(())
301 }
302
303 /// Sets the cuSOLVER logger verbosity level.
304 ///
305 /// # Errors
306 ///
307 /// Returns an error if cuSOLVER rejects the logging level.
308 pub fn set_logger_level(level: i32) -> Result<()> {
309 unsafe {
310 try_ffi!(sys::cusolverDnLoggerSetLevel(level))?;
311 }
312 Ok(())
313 }
314
315 /// Sets the cuSOLVER logger mask.
316 ///
317 /// # Errors
318 ///
319 /// Returns an error if cuSOLVER rejects the logging mask.
320 pub fn set_logger_mask(mask: i32) -> Result<()> {
321 unsafe {
322 try_ffi!(sys::cusolverDnLoggerSetMask(mask))?;
323 }
324 Ok(())
325 }
326
327 /// Sets the FILE handle used for cuSOLVER logging.
328 ///
329 /// Once registered, the file handle must remain open until another handle is
330 /// installed or logging is disabled.
331 ///
332 /// # Safety
333 ///
334 /// `file` must be a valid `FILE` handle for as long as cuSOLVER may write to it.
335 ///
336 /// # Errors
337 ///
338 /// Returns an error if cuSOLVER rejects the file handle.
339 pub unsafe fn set_logger_file(file: *mut sys::FILE) -> Result<()> {
340 unsafe {
341 try_ffi!(sys::cusolverDnLoggerSetFile(file))?;
342 }
343 Ok(())
344 }
345
346 /// Sets the cuSOLVER logging output file by path.
347 ///
348 /// # Errors
349 ///
350 /// Returns an error if `path` cannot be converted to a C string or if
351 /// cuSOLVER cannot open the log file.
352 pub fn set_logger_path(path: impl AsRef<Path>) -> Result<()> {
353 let path = path_to_cstring(path.as_ref())?;
354 unsafe {
355 try_ffi!(sys::cusolverDnLoggerOpenFile(path.as_ptr()))?;
356 }
357 Ok(())
358 }
359
360 /// Disables cuSOLVER logging for the current process.
361 ///
362 /// # Errors
363 ///
364 /// Returns an error if cuSOLVER cannot disable logging.
365 pub fn disable_logger() -> Result<()> {
366 unsafe {
367 try_ffi!(sys::cusolverDnLoggerForceDisable())?;
368 }
369 Ok(())
370 }
371
372 /// Returns the raw cuSOLVER dense handle.
373 ///
374 /// The returned handle is borrowed and remains valid only while this
375 /// context and its underlying CUDA context are alive.
376 pub fn as_raw(&self) -> sys::cusolverDnHandle_t {
377 self.handle.raw
378 }
379}
380
381impl Drop for Handle {
382 fn drop(&mut self) {
383 if let Err(err) = self.cuda_ctx.bind() {
384 #[cfg(debug_assertions)]
385 eprintln!("failed to bind cuda context before destroying cusolver handle: {err}");
386 }
387
388 unsafe {
389 if let Err(err) = try_ffi!(sys::cusolverDnDestroy(self.raw)) {
390 #[cfg(debug_assertions)]
391 eprintln!("failed to destroy cusolver context: {err}");
392 }
393 }
394 }
395}