Skip to main content

singe_cusolver/
context.rs

1#[allow(unused_imports)]
2use crate::{dense::*, svd::*};
3
4use std::{path::Path, ptr, sync::Arc};
5
6use singe_core::path_to_cstring;
7use singe_cuda::{context::Context as CudaContext, stream::Stream, types::EmulationStrategy};
8use singe_cuda_sys::runtime;
9
10use crate::{
11    error::{Error, Result},
12    sys, try_ffi,
13    types::{DeterministicMode, MathMode},
14};
15
/// A stateful cuSOLVER handle.
///
/// Use one context per host thread or concurrent task. The handle is movable
/// between threads, but it is intentionally not `Clone` or `Sync`.
///
/// The underlying cuSOLVER handle is destroyed when this value is dropped.
#[derive(Debug)]
pub struct Context {
    // Owns the raw cuSOLVER handle and the CUDA context it was created on.
    handle: Handle,
}
24
/// Owning wrapper around a raw cuSOLVER dense handle.
///
/// On drop, `cuda_ctx` is bound and then `raw` is destroyed.
#[derive(Debug)]
struct Handle {
    /// Raw handle produced by `cusolverDnCreate` in `Context::create`.
    raw: sys::cusolverDnHandle_t,
    /// CUDA context the handle was created for; kept alive for the handle's
    /// whole lifetime and bound again before the handle is destroyed.
    cuda_ctx: Arc<CudaContext>,
}

// SAFETY: cuSOLVER handles are stateful and stream-bound. The owner may move
// between threads, but callers need exclusive access to mutate handle state.
// `Handle` is not `Sync` (no `Sync` impl is provided and the raw pointer field
// is `!Sync`), so shared references to it cannot cross threads.
// NOTE(review): this impl also vouches for `cuda_ctx`; it is only sound if
// `Arc<CudaContext>` is itself `Send` — confirm against `singe_cuda`.
unsafe impl Send for Handle {}
34
/// The stream bound to a cuSOLVER handle, as reported by [`Context::stream`].
#[derive(Debug, Clone)]
pub enum StreamBinding {
    /// The CUDA default stream (cuSOLVER reported a null stream).
    Default,
    /// A non-default stream associated with the same CUDA context as the
    /// handle that reported it.
    Borrowed(BorrowedStream),
}
43
/// A stream borrowed from a CUDA context, associated with a cuSOLVER handle.
///
/// Only the raw stream handle is stored, together with the CUDA context it
/// belongs to. This value does not own the stream and does not keep it alive.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    /// Raw CUDA stream handle as reported by cuSOLVER.
    handle: runtime::cudaStream_t,
    /// CUDA context the stream is associated with.
    cuda_ctx: Arc<CudaContext>,
}
50
51impl BorrowedStream {
52    /// Returns the raw CUDA stream handle.
53    pub const fn as_raw(&self) -> runtime::cudaStream_t {
54        self.handle
55    }
56
57    /// Returns a reference to the CUDA context this stream belongs to.
58    pub fn context(&self) -> &CudaContext {
59        self.cuda_ctx.as_ref()
60    }
61}
62
63impl Context {
64    /// Creates a cuSOLVER dense handle for the given CUDA context.
65    /// Call this before invoking other cuSOLVER operations through this wrapper.
66    ///
67    /// cuSOLVER allocates the GPU-side resources it needs here. On the first
68    /// application-defined stream passed to [`Context::set_stream`], cuSOLVER may also
69    /// allocate an internal workspace.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if the CUDA context cannot be bound, if cuSOLVER cannot
74    /// create a handle, or if cuSOLVER returns a null handle.
75    pub fn create(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
76        cuda_ctx.bind()?;
77
78        let mut handle = ptr::null_mut();
79        unsafe {
80            try_ffi!(sys::cusolverDnCreate(&raw mut handle))?;
81        }
82
83        if handle.is_null() {
84            return Err(Error::NullHandle);
85        }
86
87        Ok(Self {
88            handle: Handle {
89                raw: handle,
90                cuda_ctx: Arc::clone(cuda_ctx),
91            },
92        })
93    }
94
95    /// Returns the underlying CUDA context used by this cuSOLVER handle.
96    pub fn cuda_context(&self) -> &Arc<CudaContext> {
97        &self.handle.cuda_ctx
98    }
99
100    /// Binds the underlying CUDA context associated with this handle.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the CUDA context cannot be bound.
105    pub fn bind(&self) -> Result<()> {
106        Ok(self.cuda_context().bind()?)
107    }
108
109    /// Ensures `stream` belongs to the same CUDA context as this handle.
110    ///
111    /// Returns an error if the stream belongs to a different context.
112    pub fn ensure_stream(&self, stream: &Stream) -> Result<()> {
113        if self.cuda_context().as_ref() != stream.context() {
114            return Err(Error::StreamContextMismatch);
115        }
116
117        self.bind()
118    }
119
120    /// Returns the stream currently used by this cuSOLVER handle.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
125    /// cannot report the current stream.
126    pub fn stream(&self) -> Result<StreamBinding> {
127        self.bind()?;
128
129        let mut stream = ptr::null_mut();
130        unsafe {
131            try_ffi!(sys::cusolverDnGetStream(self.as_raw(), &raw mut stream))?;
132        }
133
134        Ok(if stream.is_null() {
135            StreamBinding::Default
136        } else {
137            StreamBinding::Borrowed(BorrowedStream {
138                handle: stream,
139                cuda_ctx: Arc::clone(self.cuda_context()),
140            })
141        })
142    }
143
144    /// Sets the stream used by this cuSOLVER handle.
145    ///
146    /// Passing `None` restores the CUDA default stream.
147    ///
148    /// # Errors
149    ///
150    /// Returns an error if `stream` belongs to another CUDA context, if the CUDA
151    /// context cannot be bound, or if cuSOLVER rejects the stream.
152    pub fn set_stream(&self, stream: Option<&Stream>) -> Result<()> {
153        if let Some(stream) = stream {
154            self.ensure_stream(stream)?;
155        } else {
156            self.bind()?;
157        }
158
159        unsafe {
160            try_ffi!(sys::cusolverDnSetStream(
161                self.as_raw(),
162                match stream {
163                    Some(stream) => stream.as_raw(),
164                    None => ptr::null_mut(),
165                },
166            ))?;
167        }
168        Ok(())
169    }
170
171    /// Returns the deterministic mode currently configured on this handle.
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
176    /// cannot report the deterministic mode.
177    pub fn deterministic_mode(&self) -> Result<DeterministicMode> {
178        self.bind()?;
179
180        let mut mode = sys::cusolverDeterministicMode_t::CUSOLVER_DETERMINISTIC_RESULTS;
181        unsafe {
182            try_ffi!(sys::cusolverDnGetDeterministicMode(
183                self.as_raw(),
184                &raw mut mode,
185            ))?;
186        }
187        Ok(mode.into())
188    }
189
190    /// Sets the deterministic mode for operations executed through this handle.
191    ///
192    /// Allowing non-deterministic results may improve performance for some
193    /// operations, including [`xgeqrf`], [`xgesvd`], [`xgesvdr`], and [`xgesvdp`].
194    ///
195    /// # Errors
196    ///
197    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
198    /// rejects the deterministic mode.
199    pub fn set_deterministic_mode(&self, mode: DeterministicMode) -> Result<()> {
200        self.bind()?;
201        unsafe {
202            try_ffi!(sys::cusolverDnSetDeterministicMode(
203                self.as_raw(),
204                mode.into(),
205            ))?;
206        }
207        Ok(())
208    }
209
210    /// Returns the math mode currently configured on this handle.
211    ///
212    /// See [`MathMode`] for the supported wrapper values.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
217    /// cannot report the math mode.
218    pub fn math_mode(&self) -> Result<MathMode> {
219        self.bind()?;
220
221        let mut mode = sys::cusolverMathMode_t::CUSOLVER_DEFAULT_MATH;
222        unsafe {
223            try_ffi!(sys::cusolverDnGetMathMode(self.as_raw(), &raw mut mode))?;
224        }
225        Ok(mode.into())
226    }
227
228    /// Sets the math mode for operations executed through this handle.
229    ///
230    /// See [`MathMode`] for the supported wrapper values and combinations.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
235    /// rejects the math mode.
236    pub fn set_math_mode(&self, mode: MathMode) -> Result<()> {
237        self.bind()?;
238        unsafe {
239            try_ffi!(sys::cusolverDnSetMathMode(self.as_raw(), mode.into()))?;
240        }
241        Ok(())
242    }
243
244    /// Returns the emulation strategy configured on this handle.
245    ///
246    /// This only affects operations that use one of the emulated math modes
247    /// described by [`MathMode`].
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
252    /// cannot report the emulation strategy.
253    pub fn emulation_strategy(&self) -> Result<EmulationStrategy> {
254        self.bind()?;
255
256        let mut strategy = EmulationStrategy::Default.into();
257        unsafe {
258            try_ffi!(sys::cusolverDnGetEmulationStrategy(
259                self.as_raw(),
260                &raw mut strategy,
261            ))?;
262        }
263        Ok(strategy.into())
264    }
265
266    /// Sets the emulation strategy for operations executed through this handle.
267    ///
268    /// This only affects operations that use one of the emulated math modes
269    /// described by [`MathMode`].
270    ///
271    /// # Errors
272    ///
273    /// Returns an error if the CUDA context cannot be bound or if cuSOLVER
274    /// rejects the emulation strategy.
275    pub fn set_emulation_strategy(&self, strategy: EmulationStrategy) -> Result<()> {
276        self.bind()?;
277        unsafe {
278            try_ffi!(sys::cusolverDnSetEmulationStrategy(
279                self.as_raw(),
280                strategy.into(),
281            ))?;
282        }
283        Ok(())
284    }
285
286    /// Installs the cuSOLVER logger callback.
287    ///
288    /// # Safety
289    ///
290    /// `callback`, if present, must remain valid for use by cuSOLVER and must
291    /// follow the callback ABI expected by the library.
292    ///
293    /// # Errors
294    ///
295    /// Returns an error if cuSOLVER rejects the callback.
296    pub unsafe fn set_logger_callback(callback: sys::cusolverDnLoggerCallback_t) -> Result<()> {
297        unsafe {
298            try_ffi!(sys::cusolverDnLoggerSetCallback(callback))?;
299        }
300        Ok(())
301    }
302
303    /// Sets the cuSOLVER logger verbosity level.
304    ///
305    /// # Errors
306    ///
307    /// Returns an error if cuSOLVER rejects the logging level.
308    pub fn set_logger_level(level: i32) -> Result<()> {
309        unsafe {
310            try_ffi!(sys::cusolverDnLoggerSetLevel(level))?;
311        }
312        Ok(())
313    }
314
315    /// Sets the cuSOLVER logger mask.
316    ///
317    /// # Errors
318    ///
319    /// Returns an error if cuSOLVER rejects the logging mask.
320    pub fn set_logger_mask(mask: i32) -> Result<()> {
321        unsafe {
322            try_ffi!(sys::cusolverDnLoggerSetMask(mask))?;
323        }
324        Ok(())
325    }
326
327    /// Sets the FILE handle used for cuSOLVER logging.
328    ///
329    /// Once registered, the file handle must remain open until another handle is
330    /// installed or logging is disabled.
331    ///
332    /// # Safety
333    ///
334    /// `file` must be a valid `FILE` handle for as long as cuSOLVER may write to it.
335    ///
336    /// # Errors
337    ///
338    /// Returns an error if cuSOLVER rejects the file handle.
339    pub unsafe fn set_logger_file(file: *mut sys::FILE) -> Result<()> {
340        unsafe {
341            try_ffi!(sys::cusolverDnLoggerSetFile(file))?;
342        }
343        Ok(())
344    }
345
346    /// Sets the cuSOLVER logging output file by path.
347    ///
348    /// # Errors
349    ///
350    /// Returns an error if `path` cannot be converted to a C string or if
351    /// cuSOLVER cannot open the log file.
352    pub fn set_logger_path(path: impl AsRef<Path>) -> Result<()> {
353        let path = path_to_cstring(path.as_ref())?;
354        unsafe {
355            try_ffi!(sys::cusolverDnLoggerOpenFile(path.as_ptr()))?;
356        }
357        Ok(())
358    }
359
360    /// Disables cuSOLVER logging for the current process.
361    ///
362    /// # Errors
363    ///
364    /// Returns an error if cuSOLVER cannot disable logging.
365    pub fn disable_logger() -> Result<()> {
366        unsafe {
367            try_ffi!(sys::cusolverDnLoggerForceDisable())?;
368        }
369        Ok(())
370    }
371
372    /// Returns the raw cuSOLVER dense handle.
373    ///
374    /// The returned handle is borrowed and remains valid only while this
375    /// context and its underlying CUDA context are alive.
376    pub fn as_raw(&self) -> sys::cusolverDnHandle_t {
377        self.handle.raw
378    }
379}
380
381impl Drop for Handle {
382    fn drop(&mut self) {
383        if let Err(err) = self.cuda_ctx.bind() {
384            #[cfg(debug_assertions)]
385            eprintln!("failed to bind cuda context before destroying cusolver handle: {err}");
386        }
387
388        unsafe {
389            if let Err(err) = try_ffi!(sys::cusolverDnDestroy(self.raw)) {
390                #[cfg(debug_assertions)]
391                eprintln!("failed to destroy cusolver context: {err}");
392            }
393        }
394    }
395}