singe_cusolver/irs.rs
1#[allow(unused_imports)]
2use crate::error::Status;
3
4use std::{ptr, slice};
5
6use singe_cuda::{data_type::DataTypeLike, memory::DeviceMemory};
7
8use crate::{
9 context::Context,
10 error::{Error, Result},
11 layout::{MatrixMut, MatrixRef},
12 sys, try_ffi,
13 types::{IrsRefinement, PrecisionType},
14 utility::{to_i32, to_u64},
15};
16
/// Owned wrapper around a cuSOLVER IRS parameter structure
/// (`cusolverDnIRSParams_t`).
///
/// Created with [`IrsParams::create`]; the handle is destroyed when the
/// wrapper is dropped. The wrapper also records the main and lowest precisions
/// it has set, so that [`xgesv_buffer_size`] can fill in missing values from
/// the element type of a solve and reject a conflicting main precision.
#[derive(Debug)]
pub struct IrsParams {
    /// Raw cuSOLVER handle, owned by this wrapper and destroyed in `Drop`.
    handle: sys::cusolverDnIRSParams_t,
    /// Main (input/output) precision last set through this wrapper, if any.
    main_precision: Option<PrecisionType>,
    /// Lowest (internal factorization) precision last set through this
    /// wrapper, if any.
    lowest_precision: Option<PrecisionType>,
}
23
/// Owned wrapper around a cuSOLVER IRS info structure (`cusolverDnIRSInfos_t`)
/// that reports the outcome of a single IRS solve.
///
/// Obtain one with [`IrsInfos::create`]; the handle is destroyed when the
/// wrapper is dropped.
// NOTE(review): the derived `Default` builds a value whose `handle` is the
// handle type's default (presumably null if the sys type is a pointer alias),
// not one returned by `cusolverDnIRSInfosCreate`; `Drop` would still pass it to
// `cusolverDnIRSInfosDestroy`. Prefer `IrsInfos::create` — confirm whether the
// derive is intentional.
#[derive(Debug, Default)]
pub struct IrsInfos {
    /// Raw cuSOLVER handle, owned by this wrapper and destroyed in `Drop`.
    handle: sys::cusolverDnIRSInfos_t,
    /// Set by [`IrsInfos::request_residual_history`]; the residual-history
    /// accessors refuse to run while it is `false`.
    residual_history_requested: bool,
}
29
/// One row of an IRS convergence-history matrix.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResidualHistoryEntry<T> {
    /// Total number of iterations (outer plus inner) performed up to this
    /// outer iteration, stored with the same element type as the residual.
    pub total_iterations: T,
    /// Residual norm reached at this outer iteration.
    pub residual_norm: T,
}
35
/// Host-side copy of an IRS convergence history.
#[derive(Debug, Clone, PartialEq)]
pub struct ResidualHistory<T> {
    /// History rows; presumably the `OuterNiters + 1` valid rows reported by
    /// cuSOLVER (see `copy_residual_history` to confirm).
    pub rows: Vec<ResidualHistoryEntry<T>>,
    /// Leading dimension (`MaxIters + 1`) of the two-column matrix cuSOLVER
    /// returned.
    pub leading_dimension: usize,
}
41
/// Problem dimensions of an IRS `gesv` solve.
///
/// Used to query the workspace size ([`IrsSolve::workspace_size`]) and to run
/// the solve ([`IrsSolve::execute`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IrsSolve {
    /// Forwarded to cuSOLVER as `n` (presumably the order of the square
    /// system matrix `A`).
    pub n: usize,
    /// Number of right-hand sides; forwarded to cuSOLVER as `nrhs`.
    pub right_hand_sides: usize,
}
47
48impl IrsSolve {
49 pub fn new(n: usize, right_hand_sides: usize) -> Self {
50 Self {
51 n,
52 right_hand_sides,
53 }
54 }
55
56 pub fn workspace_size<T: DataTypeLike>(
57 self,
58 ctx: &Context,
59 params: &mut IrsParams,
60 ) -> Result<usize> {
61 xgesv_buffer_size::<T>(ctx, params, self.n, self.right_hand_sides)
62 }
63
64 pub fn execute<T: DataTypeLike>(
65 self,
66 ctx: &Context,
67 params: &mut IrsParams,
68 infos: &IrsInfos,
69 bindings: IrsSolveBindings<'_, T>,
70 ) -> Result<i32> {
71 xgesv(
72 ctx,
73 params,
74 infos,
75 self.n,
76 self.right_hand_sides,
77 bindings.a,
78 bindings.b,
79 bindings.x,
80 bindings.device_workspace,
81 bindings.dev_info,
82 )
83 }
84}
85
/// Borrowed operands for [`IrsSolve::execute`].
#[derive(Debug)]
pub struct IrsSolveBindings<'a, T> {
    /// System matrix `A`. Borrowed mutably; presumably because cuSOLVER may
    /// overwrite it during the solve — confirm against the IRS gesv docs.
    pub a: MatrixMut<'a, T>,
    /// Right-hand sides `B` (read-only).
    pub b: MatrixRef<'a, T>,
    /// Output solution `X`.
    pub x: MatrixMut<'a, T>,
    /// Device workspace; size it with [`IrsSolve::workspace_size`] (in bytes).
    pub device_workspace: &'a mut DeviceMemory<u8>,
    /// Device-side status word written by the solver (presumably cuSOLVER's
    /// `dinfo`) — TODO confirm.
    pub dev_info: &'a mut DeviceMemory<i32>,
}
94
// SAFETY: each wrapper exclusively owns its cuSOLVER handle, so moving a
// wrapper to another thread (`Send`) transfers sole ownership of that handle.
// `Sync` relies on the claim that operations reachable through `&self`
// (getters such as `max_iterations`/`niters` and `as_raw`) only read from the
// handle, while configuration goes through `&mut self`.
// NOTE(review): `IrsSolve::execute` passes `infos: &IrsInfos` to `xgesv`, so
// the solver presumably writes into an info structure reached through a
// shared reference. With `Sync`, two threads could solve into the same
// `IrsInfos` concurrently. Confirm that cuSOLVER tolerates this, or consider
// requiring `&mut IrsInfos` for solves.
unsafe impl Send for IrsParams {}
unsafe impl Sync for IrsParams {}
unsafe impl Send for IrsInfos {}
unsafe impl Sync for IrsInfos {}
101
102impl IrsParams {
103 /// Creates and initializes the parameter structure for IRS solvers such as
104 /// [`xgesv`] and [`xgels`].
105 ///
106 /// The returned parameter structure can be reused across calls to the same
107 /// IRS solver or to different IRS solvers.
108 ///
109 /// In CUDA 10.2, the behavior was different and a new parameter structure
110 /// was required for each IRS solve call.
111 ///
112 /// You can also reconfigure the parameters between solves, but only after
113 /// the previous IRS call has completed.
114 ///
115 /// # Errors
116 ///
117 /// Returns an error if cuSOLVER cannot allocate the required resources
118 /// or does not return a valid handle.
119 pub fn create() -> Result<Self> {
120 let mut handle = ptr::null_mut();
121 unsafe {
122 try_ffi!(sys::cusolverDnIRSParamsCreate(&raw mut handle))?;
123 }
124 if handle.is_null() {
125 return Err(Error::NullHandle);
126 }
127 let mut params = Self {
128 handle,
129 main_precision: None,
130 lowest_precision: None,
131 };
132 params.set_refinement_solver(IrsRefinement::None)?;
133 Ok(params)
134 }
135
136 /// Sets the refinement solver used by IRS operations such as [`xgesv`] and
137 /// [`xgels`].
138 ///
139 /// Configure the refinement algorithm before the first IRS solve. Newly created [`IrsParams`] do not set one by default.
140 ///
141 /// The supported values are described below.
142 ///
143 /// [`IrsRefinement::NotSet`]: Solver is not set. The IRS solver returns an
144 /// error if this value is used.
145 ///
146 /// [`IrsRefinement::None`]: No refinement solver; the IRS solver performs a factorization followed by a solve without any refinement.
147 /// For example, if the IRS solver was [`xgesv`], this is equivalent to an
148 /// [`xgesv`] solve without refinement, with the factorization carried out in
149 /// the lowest configured precision.
150 /// If both the main and lowest precision are [`PrecisionType::R64F`], the
151 /// solve is effectively performed in `f64`.
152 ///
153 /// [`IrsRefinement::Classical`]: Classical iterative refinement solver.
154 /// Similar to the value used in LAPACK operations.
155 ///
156 /// [`IrsRefinement::Gmres`]: GMRES (Generalized Minimal Residual) based iterative refinement solver.
157 /// Recent studies use GMRES as a refinement solver that can outperform
158 /// classical iterative refinement.
159 /// Recommended setting based on cuSOLVER experimentation.
160 ///
161 /// [`IrsRefinement::ClassicalGmres`]: Classical iterative refinement solver that uses the GMRES (Generalized Minimal Residual) internally to solve the correction equation at each iteration.
162 /// The classical refinement iteration is the outer iteration, and GMRES is
163 /// the inner iteration.
164 /// If the tolerance of the inner GMRES is set very low, for
165 /// example near machine precision, then the outer *classical refinement
166 /// iteration* performs only one iteration and this option behaves like
167 /// [`IrsRefinement::Gmres`].
168 ///
169 /// [`IrsRefinement::GmresGmres`]: GMRES-based iterative refinement solver
170 /// that uses another GMRES solve internally for the preconditioned system.
171 ///
172 /// # Errors
173 ///
174 /// Returns an error if cuSOLVER rejects the parameter structure.
175 pub fn set_refinement_solver(&mut self, refinement: IrsRefinement) -> Result<()> {
176 unsafe {
177 try_ffi!(sys::cusolverDnIRSParamsSetRefinementSolver(
178 self.as_raw(),
179 refinement.into(),
180 ))?;
181 }
182 Ok(())
183 }
184
185 /// Sets the main precision for the Iterative Refinement Solver (IRS).
186 ///
187 /// The main precision is the type of the input and output data.
188 /// Configure both the main and lowest precision before the first IRS solve. Those
189 /// values are not inferred when the parameter structure is created because
190 /// they depend on the input/output data type and the requested solver
191 /// configuration. You can set them independently or together with
192 /// [`IrsParams::set_solver_precisions`].
193 ///
194 /// # Errors
195 ///
196 /// Returns an error if cuSOLVER rejects the parameter structure.
197 pub fn set_main_precision(&mut self, precision: PrecisionType) -> Result<()> {
198 unsafe {
199 try_ffi!(sys::cusolverDnIRSParamsSetSolverMainPrecision(
200 self.as_raw(),
201 precision.into(),
202 ))?;
203 }
204 self.main_precision = Some(precision);
205 Ok(())
206 }
207
208 /// Sets the lowest precision that the IRS solver may use.
209 ///
210 /// The lowest precision is the minimum compute precision used
211 /// during the LU factorization process.
212 ///
213 /// Configure both the main and lowest precision before the first IRS solve. They
214 /// are not inferred when creating the parameter structure because they
215 /// depend on the input and output data types and the requested solver
216 /// configuration.
217 /// Usually the lowest precision defines the speedup that can be achieved.
218 /// The ratio between the performance of the lowest precision and the main
219 /// precision gives an approximate upper bound on the speedup.
220 /// More precisely, it depends on many factors, but for large matrices it is
221 /// often tied to the performance ratio of large GEMM-like kernels.
222 /// For instance, if the input/output precision is real double precision
223 /// [`PrecisionType::R64F`] and the lowest precision is
224 /// [`PrecisionType::R32F`], then a speedup of at most about 2x is expected
225 /// for large problem sizes.
226 /// If the lowest precision is [`PrecisionType::R16F`], expect 3x-4x.
227 /// A reasonable strategy accounts for the number of right-hand sides, the matrix size, and the convergence rate.
228 ///
229 /// # Errors
230 ///
231 /// Returns an error if cuSOLVER rejects the parameter structure.
232 pub fn set_lowest_precision(&mut self, precision: PrecisionType) -> Result<()> {
233 unsafe {
234 try_ffi!(sys::cusolverDnIRSParamsSetSolverLowestPrecision(
235 self.as_raw(),
236 precision.into(),
237 ))?;
238 }
239 self.lowest_precision = Some(precision);
240 Ok(())
241 }
242
243 /// Sets both the main and lowest precision for the Iterative Refinement
244 /// Solver (IRS).
245 ///
246 /// The main precision is the precision of the input and output data.
247 /// The lowest precision is the minimum compute precision used
248 /// during the LU factorization process.
249 ///
250 /// Configure both values before the first IRS solve. They are not inferred when
251 /// creating the parameter structure because they depend on the input and
252 /// output data types and the requested solver configuration.
253 ///
254 /// Convenience wrapper around
255 /// [`IrsParams::set_main_precision`] and
256 /// [`IrsParams::set_lowest_precision`].
257 /// All possible combinations of main/lowest precision are described in the table below.
258 /// Usually the lowest precision defines the speedup that can be achieved.
259 /// The ratio between the performance of the lowest precision and the main
260 /// precision gives an approximate upper bound on the speedup.
261 /// More precisely, it depends on many factors, but for large matrices it is
262 /// often tied to the performance ratio of large GEMM-like kernels.
263 /// For instance, if the input/output precision is real double precision
264 /// [`PrecisionType::R64F`] and the lowest precision is
265 /// [`PrecisionType::R32F`], then a speedup of at most about 2x is expected
266 /// for large problem sizes.
267 /// If the lowest precision is [`PrecisionType::R16F`], expect 3x-4x.
268 /// A reasonable strategy accounts for the number of right-hand sides, the matrix size, and the convergence rate.
269 ///
270 /// **Supported input/output data type and lower precision for the IRS solver**
271 ///
272 /// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
273 /// | --- | --- |
274 /// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
275 /// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
276 /// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
277 /// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
278 ///
279 /// # Errors
280 ///
281 /// Returns an error if cuSOLVER rejects the parameter structure.
282 pub fn set_solver_precisions(
283 &mut self,
284 main_precision: PrecisionType,
285 lowest_precision: PrecisionType,
286 ) -> Result<()> {
287 unsafe {
288 try_ffi!(sys::cusolverDnIRSParamsSetSolverPrecisions(
289 self.as_raw(),
290 main_precision.into(),
291 lowest_precision.into(),
292 ))?;
293 }
294 self.main_precision = Some(main_precision);
295 self.lowest_precision = Some(lowest_precision);
296 Ok(())
297 }
298
299 /// Sets the tolerance for the refinement solver.
300 /// By default it is such that all the RHS satisfy:
301 ///
302 /// `RNRM < SQRT(N)*XNRM*ANRM*EPS*BWDMAX` where
303 ///
304 /// * RNRM is the infinity-norm of the residual
305 /// * XNRM is the infinity-norm of the solution
306 /// * ANRM is the infinity-operator-norm of the matrix A
307 /// * EPS is the machine epsilon for the input/output data type that matches
308 /// LAPACK `xLAMCH('Epsilon')`
309 /// * BWDMAX, the value BWDMAX is fixed to 1.0
310 ///
311 /// Use this to set the tolerance to a lower or higher value.
312 /// The tolerance value is always stored in real double precision,
313 /// regardless of the input and output data type.
314 ///
315 /// # Errors
316 ///
317 /// Returns an error if cuSOLVER rejects the parameter structure.
318 pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
319 unsafe {
320 try_ffi!(sys::cusolverDnIRSParamsSetTol(self.as_raw(), tolerance))?;
321 }
322 Ok(())
323 }
324
325 /// Sets the tolerance for the inner refinement solver when
326 /// the refinement solver consists of two levels, for example
327 /// [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`].
328 /// Ignored for one-level refinement solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`].
329 /// The default value is 1e-4.
330 /// This sets the tolerance for the inner solver, such as the inner GMRES.
331 /// For example, if the refinement solver is
332 /// [`IrsRefinement::ClassicalGmres`], setting this tolerance means that the
333 /// inner GMRES solver converges to that tolerance at each outer
334 /// iteration of the classical refinement solver.
335 /// The tolerance value is always stored in real double precision,
336 /// regardless of the input and output data type.
337 ///
338 /// # Errors
339 ///
340 /// Returns an error if cuSOLVER rejects the parameter structure.
341 pub fn set_inner_tolerance(&mut self, tolerance: f64) -> Result<()> {
342 unsafe {
343 try_ffi!(sys::cusolverDnIRSParamsSetTolInner(
344 self.as_raw(),
345 tolerance,
346 ))?;
347 }
348 Ok(())
349 }
350
351 /// Sets the total number of allowed refinement iterations before the solver stops.
352 /// The total is the sum of the outer and inner iterations. Inner iterations are meaningful when a two-level refinement solver is configured.
353 /// The default value is 50.
354 ///
355 /// # Errors
356 ///
357 /// Returns an error if cuSOLVER rejects the parameter structure.
358 pub fn set_max_iterations(&mut self, max_iterations: i32) -> Result<()> {
359 unsafe {
360 try_ffi!(sys::cusolverDnIRSParamsSetMaxIters(
361 self.as_raw(),
362 max_iterations,
363 ))?;
364 }
365 Ok(())
366 }
367
368 /// Sets the maximum number of iterations allowed for the inner refinement solver.
369 /// Ignored for one-level refinement solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`].
370 /// The inner refinement solver stops after reaching either the inner tolerance or `MaxItersInner`.
371 /// The default value is 50.
372 /// Cannot be larger than `MaxIters` because `MaxIters` is the total number of allowed iterations.
373 /// If [`IrsParams::set_max_iterations`] is called after this method, it has priority and overwrites `MaxItersInner` with `min(MaxIters, MaxItersInner)`.
374 ///
375 /// # Errors
376 ///
377 /// Returns an error if `max_iterations` is larger than `MaxIters`, or if
378 /// cuSOLVER rejects the parameter structure.
379 pub fn set_max_inner_iterations(&mut self, max_iterations: i32) -> Result<()> {
380 unsafe {
381 try_ffi!(sys::cusolverDnIRSParamsSetMaxItersInner(
382 self.as_raw(),
383 max_iterations,
384 ))?;
385 }
386 Ok(())
387 }
388
389 /// Returns the current maximum-iteration setting in this parameter structure.
390 /// Current parameter configuration, distinct from [`IrsInfos::max_iterations`], which returns the maximum number of iterations allowed for a particular IRS solver call.
391 /// The parameter structure can be reused across many IRS solver calls.
392 /// The allowed `MaxIters` value can change between calls, while the `Infos` structure contains information about one particular call and cannot be reused for different calls.
393 ///
394 /// # Errors
395 ///
396 /// Returns an error if cuSOLVER rejects the parameter structure.
397 pub fn max_iterations(&self) -> Result<i32> {
398 let mut value = 0;
399 unsafe {
400 try_ffi!(sys::cusolverDnIRSParamsGetMaxIters(
401 self.as_raw(),
402 &raw mut value,
403 ))?;
404 }
405 Ok(value)
406 }
407
408 /// Enables fallback to the main precision if the Iterative Refinement Solver (IRS) fails to converge.
409 /// If the IRS solver fails to converge, it returns a non-convergence code such as `niter < 0`.
410 /// With fallback disabled, it returns the non-convergent solution as-is.
411 /// With fallback enabled, it falls back to the main precision, which is the input/output data precision, and solves the problem again from scratch.
412 /// This fallback is the default behavior.
413 ///
414 /// # Errors
415 ///
416 /// Returns an error if cuSOLVER rejects the parameter structure.
417 pub fn enable_fallback(&mut self) -> Result<()> {
418 unsafe {
419 try_ffi!(sys::cusolverDnIRSParamsEnableFallback(self.as_raw()))?;
420 }
421 Ok(())
422 }
423
424 /// Disables fallback to the main precision if the Iterative Refinement Solver (IRS) fails to converge.
425 /// If the IRS solver fails to converge, it returns a non-convergence code such as `niter < 0`.
426 /// With fallback disabled, the returned solution is whatever the refinement solver reached before returning.
427 /// Disabling fallback does not guarantee that the solution is accurate.
428 /// Re-enable fallback with [`IrsParams::enable_fallback`].
429 ///
430 /// # Errors
431 ///
432 /// Returns an error if cuSOLVER rejects the parameter structure.
433 pub fn disable_fallback(&mut self) -> Result<()> {
434 unsafe {
435 try_ffi!(sys::cusolverDnIRSParamsDisableFallback(self.as_raw()))?;
436 }
437 Ok(())
438 }
439
440 fn ensure_type_precision<T: DataTypeLike>(&mut self) -> Result<()> {
441 let precision = PrecisionType::from_data_type(T::data_type())
442 .ok_or(Error::InvalidPrecisionConfiguration)?;
443 match self.main_precision {
444 Some(existing) if existing != precision => {
445 return Err(Error::InvalidPrecisionConfiguration);
446 }
447 None => self.set_main_precision(precision)?,
448 _ => {}
449 }
450 if self.lowest_precision.is_none() {
451 self.set_lowest_precision(precision)?;
452 }
453 Ok(())
454 }
455
456 pub fn as_raw(&self) -> sys::cusolverDnIRSParams_t {
457 self.handle
458 }
459}
460
461impl Drop for IrsParams {
462 fn drop(&mut self) {
463 unsafe {
464 if let Err(err) = try_ffi!(sys::cusolverDnIRSParamsDestroy(self.handle)) {
465 #[cfg(debug_assertions)]
466 eprintln!("failed to destroy cusolver irs params: {err}");
467 }
468 }
469 }
470}
471
472impl IrsInfos {
473 /// Creates and initializes the `Infos` structure that holds refinement information for an Iterative Refinement Solver (IRS) call.
474 /// Such information includes the total number of iterations needed to converge (`Niters`), the number of outer iterations (meaningful when a two-level preconditioner such as [`IrsRefinement::ClassicalGmres`] is used), the maximum number of iterations allowed for that call, and a pointer to the convergence-history residual norm matrix.
475 /// Construct the `Infos` structure before calling an IRS solver.
476 /// The `Infos` structure is valid for only one call to an IRS solver, since it holds information about that solve; each solve requires its own `Infos` structure.
477 ///
478 /// # Errors
479 ///
480 /// Returns an error if cuSOLVER cannot allocate the required resources
481 /// or does not return a valid handle.
482 pub fn create() -> Result<Self> {
483 let mut handle = ptr::null_mut();
484 unsafe {
485 try_ffi!(sys::cusolverDnIRSInfosCreate(&raw mut handle))?;
486 }
487 if handle.is_null() {
488 return Err(Error::NullHandle);
489 }
490 Ok(Self {
491 handle,
492 residual_history_requested: false,
493 })
494 }
495
496 /// Returns the total number of iterations performed by the IRS solver.
497 /// If this value is negative, the IRS solver did not converge. If fallback to full precision was enabled, the solver fell back to a full-precision solution.
498 /// See [`xgesv`] and [`xgels`] for the meaning of negative `niters` values.
499 ///
500 /// # Errors
501 ///
502 /// Returns an error if cuSOLVER rejects the `Infos` structure.
503 pub fn niters(&self) -> Result<i32> {
504 let mut value = 0;
505 unsafe {
506 try_ffi!(sys::cusolverDnIRSInfosGetNiters(
507 self.as_raw(),
508 &raw mut value,
509 ))?;
510 }
511 Ok(value)
512 }
513
514 /// Returns the number of iterations performed by the outer refinement loop of the IRS solver.
515 /// For one-level solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], this is the same as `Niters`.
516 /// For two-level solvers such as [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], this is the number of outer-loop iterations.
517 /// See [`IrsRefinement`] for refinement mode details.
518 ///
519 /// # Errors
520 ///
521 /// Returns an error if cuSOLVER rejects the `Infos` structure.
522 pub fn outer_niters(&self) -> Result<i32> {
523 let mut value = 0;
524 unsafe {
525 try_ffi!(sys::cusolverDnIRSInfosGetOuterNiters(
526 self.as_raw(),
527 &raw mut value,
528 ))?;
529 }
530 Ok(value)
531 }
532
533 /// Returns the maximum number of iterations allowed for the corresponding IRS solver call.
534 /// Setting used when that call happened, distinct from [`IrsParams::max_iterations`], which returns the current setting in the `params` configuration structure.
535 /// The `params` structure can be reused for many IRS solver calls.
536 /// The allowed `MaxIters` value can change between calls, while this `Infos` structure contains information about one particular call and cannot be reused for different calls.
537 ///
538 /// # Errors
539 ///
540 /// Returns an error if cuSOLVER rejects the `Infos` structure.
541 pub fn max_iterations(&self) -> Result<i32> {
542 let mut value = 0;
543 unsafe {
544 try_ffi!(sys::cusolverDnIRSInfosGetMaxIters(
545 self.as_raw(),
546 &raw mut value,
547 ))?;
548 }
549 Ok(value)
550 }
551
552 /// Asks the IRS solver to store the convergence history
553 /// (residual norms) of the refinement phase so it can later be queried with
554 /// [`IrsInfos::residual_history_f32`] or [`IrsInfos::residual_history_f64`].
555 ///
556 /// # Errors
557 ///
558 /// Returns an error if cuSOLVER rejects the `Infos` structure.
559 pub fn request_residual_history(&mut self) -> Result<()> {
560 unsafe {
561 try_ffi!(sys::cusolverDnIRSInfosRequestResidual(self.as_raw()))?;
562 }
563 self.residual_history_requested = true;
564 Ok(())
565 }
566
567 /// Returns the convergence history stored by the IRS solver when [`IrsInfos::request_residual_history`] was called before solving.
568 /// The residual norm type depends on the input and output precision.
569 /// Double-precision real and complex configurations report `f64` residuals, while single-precision real and complex configurations report `f32` residuals.
570 ///
571 /// The residual history matrix has two columns, even for multiple right-hand sides, and `MaxIters + 1` rows.
572 /// Only the first `OuterNiters + 1` rows contain residual norms; the remaining rows are undefined.
573 /// In the first column, each row `i` contains the total number of iterations performed up to outer iteration `i`.
574 /// In the second column, each row contains the residual norm for that outer iteration.
575 /// Row 0 contains the initial residual before the refinement loop starts, and subsequent rows contain residuals obtained at each outer iteration.
576 /// The history only covers the outer loop.
577 ///
578 /// If the refinement solver was [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], then `OuterNiters == Niters`, and there are `Niters + 1` rows of norms corresponding to the `Niters` outer iterations.
579 ///
580 /// If the refinement solver was [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], then `OuterNiters <= Niters` corresponds to the outer iterations performed by the outer refinement loop.
581 /// There are `OuterNiters + 1` residual norms. Row `i` corresponds to outer iteration `i`; the first column contains the total number of outer and inner iterations performed up to that step, and the second column contains the residual norm at that step.
582 ///
583 /// For example, if [`IrsRefinement::ClassicalGmres`] needs 3 outer iterations to converge and 4, 3, and 3 inner iterations at each outer iteration, it performs 10 total iterations.
584 /// Row 0 corresponds to the first residual before the refinement start, so it has 0 in its first column.
585 /// Row 1 corresponds to outer iteration 1 and contains 4 in its first column, row 2 contains 7, and row 3 contains 10.
586 ///
587 /// In summary, let `ldh = MaxIters + 1`, the leading dimension of the residual matrix. Then `residual_history[i]` contains the total number of iterations performed at outer iteration `i`, and `residual_history[i + ldh]` contains the residual norm at that outer iteration.
588 ///
589 /// # Errors
590 ///
591 /// Returns an error if residual history was not requested before solving,
592 /// or if cuSOLVER rejects the `Infos` structure.
593 pub fn residual_history_f32(&self) -> Result<ResidualHistory<f32>> {
594 if !self.residual_history_requested {
595 return Err(Error::InvalidPrecisionConfiguration);
596 }
597 let (leading_dimension, valid_rows) = self.residual_history_layout()?;
598 let mut history = ptr::null_mut();
599 unsafe {
600 try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
601 self.as_raw(),
602 &raw mut history,
603 ))?;
604 Ok(copy_residual_history(
605 history.cast::<f32>(),
606 leading_dimension,
607 valid_rows,
608 ))
609 }
610 }
611
612 /// Returns the convergence history stored by the IRS solver when [`IrsInfos::request_residual_history`] was called before solving.
613 /// The residual norm type depends on the input and output precision.
614 /// Double-precision real and complex configurations report `f64` residuals, while single-precision real and complex configurations report `f32` residuals.
615 ///
616 /// The residual history matrix has two columns, even for multiple right-hand sides, and `MaxIters + 1` rows.
617 /// Only the first `OuterNiters + 1` rows contain residual norms; the remaining rows are undefined.
618 /// In the first column, each row `i` contains the total number of iterations performed up to outer iteration `i`.
619 /// In the second column, each row contains the residual norm for that outer iteration.
620 /// Row 0 contains the initial residual before the refinement loop starts, and subsequent rows contain residuals obtained at each outer iteration.
621 /// The history only covers the outer loop.
622 ///
623 /// If the refinement solver was [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], then `OuterNiters == Niters`, and there are `Niters + 1` rows of norms corresponding to the `Niters` outer iterations.
624 ///
625 /// If the refinement solver was [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], then `OuterNiters <= Niters` corresponds to the outer iterations performed by the outer refinement loop.
626 /// There are `OuterNiters + 1` residual norms. Row `i` corresponds to outer iteration `i`; the first column contains the total number of outer and inner iterations performed up to that step, and the second column contains the residual norm at that step.
627 ///
628 /// For example, if [`IrsRefinement::ClassicalGmres`] needs 3 outer iterations to converge and 4, 3, and 3 inner iterations at each outer iteration, it performs 10 total iterations.
629 /// Row 0 corresponds to the first residual before the refinement start, so it has 0 in its first column.
630 /// Row 1 corresponds to outer iteration 1 and contains 4 in its first column, row 2 contains 7, and row 3 contains 10.
631 ///
632 /// In summary, let `ldh = MaxIters + 1`, the leading dimension of the residual matrix. Then `residual_history[i]` contains the total number of iterations performed at outer iteration `i`, and `residual_history[i + ldh]` contains the residual norm at that outer iteration.
633 ///
634 /// # Errors
635 ///
636 /// Returns an error if residual history was not requested before solving,
637 /// or if cuSOLVER rejects the `Infos` structure.
638 pub fn residual_history_f64(&self) -> Result<ResidualHistory<f64>> {
639 if !self.residual_history_requested {
640 return Err(Error::InvalidPrecisionConfiguration);
641 }
642 let (leading_dimension, valid_rows) = self.residual_history_layout()?;
643 let mut history = ptr::null_mut();
644 unsafe {
645 try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
646 self.as_raw(),
647 &raw mut history,
648 ))?;
649 Ok(copy_residual_history(
650 history.cast::<f64>(),
651 leading_dimension,
652 valid_rows,
653 ))
654 }
655 }
656
657 pub fn as_raw(&self) -> sys::cusolverDnIRSInfos_t {
658 self.handle
659 }
660
661 fn residual_history_layout(&self) -> Result<(usize, usize)> {
662 let leading_dimension = self
663 .max_iterations()?
664 .checked_add(1)
665 .ok_or(Error::InvalidResidualHistory)
666 .and_then(|value| {
667 usize::try_from(value).map_err(|_| Error::OutOfRange {
668 name: "residual history leading dimension".into(),
669 })
670 })?;
671 let valid_rows = self
672 .outer_niters()?
673 .checked_add(1)
674 .ok_or(Error::InvalidResidualHistory)
675 .and_then(|value| {
676 usize::try_from(value).map_err(|_| Error::OutOfRange {
677 name: "residual history rows".into(),
678 })
679 })?;
680
681 if valid_rows > leading_dimension {
682 return Err(Error::InvalidResidualHistory);
683 }
684
685 Ok((leading_dimension, valid_rows))
686 }
687}
688
689impl Drop for IrsInfos {
690 fn drop(&mut self) {
691 unsafe {
692 if let Err(err) = try_ffi!(sys::cusolverDnIRSInfosDestroy(self.handle)) {
693 #[cfg(debug_assertions)]
694 eprintln!("failed to destroy cusolver irs infos: {err}");
695 }
696 }
697 }
698}
699
700pub fn xgesv_buffer_size<T: DataTypeLike>(
701 ctx: &Context,
702 params: &mut IrsParams,
703 n: usize,
704 nrhs: usize,
705) -> Result<usize> {
706 ctx.bind()?;
707 if n == 0 || nrhs == 0 {
708 return Err(Error::InvalidMatrixShape);
709 }
710 params.ensure_type_precision::<T>()?;
711 let mut workspace_bytes = 0;
712 unsafe {
713 try_ffi!(sys::cusolverDnIRSXgesv_bufferSize(
714 ctx.as_raw(),
715 params.as_raw(),
716 to_i32(n, "n")?,
717 to_i32(nrhs, "nrhs")?,
718 &raw mut workspace_bytes,
719 ))?;
720 }
721 Ok(workspace_bytes as usize)
722}
723
724/// Provides the same solve as the typed cuSOLVER `gesv` entry
725/// points, but through a generic Rust wrapper that exposes IRS configuration
726/// and reporting more directly.
727/// [`xgesv`] allows additional control of the solver parameters such as setting:
728///
729/// * the main precision (input/output precision) of the solver
730/// * the lowest precision to be used internally by the solver
731/// * the refinement solver type
732/// * the maximum allowed number of iterations in the refinement phase
733/// * the tolerance of the refinement solver
734/// * the fallback to main precision
735/// * and more
736///
737/// through [`IrsParams`] and its helper methods.
738/// Moreover, [`xgesv`] provides additional output information such as the convergence history (for example, residual norms) at each iteration and the number of iterations needed to converge.
739/// [`IrsInfos`] exposes the information reported for a particular solve.
740///
741/// The returned value describes the solving process.
742/// `Ok` indicates that the solve finished successfully. An error indicates that one of the arguments is incorrect, that the parameter or info structures are misconfigured, or that the solve did not finish successfully.
743/// Check `niters` and `dinfo` for additional error details.
744/// Provide the required device workspace through `workspace`.
745/// Query the required byte count with [`xgesv_buffer_size`].
746/// Apply any required configuration through the parameter structure before calling [`xgesv_buffer_size`] so the workspace size matches that configuration.
747///
748/// Tensor Float (TF32), introduced with NVIDIA Ampere architecture GPUs, is the most robust tensor core accelerated compute mode for the iterative refinement solver.
749/// It solves a broad range of HPC problems and can provide up to 4x and 5x
750/// speedups for real and complex systems, respectively.
751/// On Volta and Turing architecture GPUs, half precision tensor core acceleration is recommended.
752/// In cases where the iterative refinement solver fails to converge to the desired accuracy (main precision, input/output data precision), it is recommended to use main precision as internal lowest precision.
753///
754/// The following table provides all possible lowest-precision values corresponding to the input/output data type.
755/// If the lowest precision matches the input/output data type, the main
756/// precision factorization is used.
757///
758/// **Supported input/output data type and lower precision for the IRS solver**
759///
760/// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
761/// | --- | --- |
762/// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
763/// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
764/// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
765/// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
766///
767/// [`xgesv_buffer_size`] returns the required workspace size in bytes for the
768/// current [`IrsParams`] configuration.
769///
770/// # Errors
771///
772/// Returns an error if cuSOLVER rejects the matrix dimensions, leading
773/// dimensions, parameter structure, info structure, or workspace. The workspace
774/// can become invalid if [`xgesv_buffer_size`] is called and then an IRS
775/// configuration value, such as the lowest precision, is changed. cuSOLVER can
776/// also report an error if host memory allocation fails, if the selected IRS
777/// configuration is not supported on the current GPU architecture, if the
778/// library has not been initialized, or if the solve ends with an internal or
779/// numerical failure. Check `niters` and `dinfo` for additional solver details.
780pub fn xgesv<T: DataTypeLike>(
781 ctx: &Context,
782 params: &mut IrsParams,
783 infos: &IrsInfos,
784 n: usize,
785 nrhs: usize,
786 a: MatrixMut<'_, T>,
787 b: MatrixRef<'_, T>,
788 x: MatrixMut<'_, T>,
789 device_workspace: &mut DeviceMemory<u8>,
790 dev_info: &mut DeviceMemory<i32>,
791) -> Result<i32> {
792 ctx.bind()?;
793 validate_matrix(n, n, a.data.len(), a.leading_dimension)?;
794 validate_matrix(n, nrhs, b.data.len(), b.leading_dimension)?;
795 validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
796 require_info_buffer(dev_info)?;
797 let workspace_bytes = xgesv_buffer_size::<T>(ctx, params, n, nrhs)?;
798 require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
799 let mut niters = 0;
800 unsafe {
801 try_ffi!(sys::cusolverDnIRSXgesv(
802 ctx.as_raw(),
803 params.as_raw(),
804 infos.as_raw(),
805 to_i32(n, "n")?,
806 to_i32(nrhs, "nrhs")?,
807 a.data.as_mut_ptr() as _,
808 to_i32(a.leading_dimension, "ldda")?,
809 b.data.as_ptr() as _,
810 to_i32(b.leading_dimension, "lddb")?,
811 x.data.as_mut_ptr() as _,
812 to_i32(x.leading_dimension, "lddx")?,
813 device_workspace.as_mut_ptr() as _,
814 to_u64(workspace_bytes, "lwork_bytes")?,
815 &raw mut niters,
816 dev_info.as_mut_ptr() as _,
817 ))?;
818 }
819 Ok(niters)
820}
821
822pub fn xgels_buffer_size<T: DataTypeLike>(
823 ctx: &Context,
824 params: &mut IrsParams,
825 m: usize,
826 n: usize,
827 nrhs: usize,
828) -> Result<usize> {
829 ctx.bind()?;
830 if m == 0 || n == 0 || nrhs == 0 || n > m {
831 return Err(Error::InvalidMatrixShape);
832 }
833 params.ensure_type_precision::<T>()?;
834 let mut workspace_bytes = 0;
835 unsafe {
836 try_ffi!(sys::cusolverDnIRSXgels_bufferSize(
837 ctx.as_raw(),
838 params.as_raw(),
839 to_i32(m, "m")?,
840 to_i32(n, "n")?,
841 to_i32(nrhs, "nrhs")?,
842 &raw mut workspace_bytes,
843 ))?;
844 }
845 Ok(workspace_bytes as usize)
846}
847
848/// Provides the same solve as the typed cuSOLVER `gels` entry
849/// points, but through a generic Rust wrapper that exposes IRS configuration
850/// and reporting more directly.
851/// [`xgels`] allows additional control of the solver parameters such as setting:
852///
853/// * the main precision (input/output precision) of the solver,
854/// * the lowest precision to be used internally by the solver,
855/// * the refinement solver type
856/// * the maximum allowed number of iterations in the refinement phase
857/// * the tolerance of the refinement solver
858/// * the fallback to main precision
859/// * and others
860///
861/// through [`IrsParams`] and its helper methods.
862/// Moreover, [`xgels`] provides additional output information such as the convergence history (for example, residual norms) at each iteration and the number of iterations needed to converge.
863/// [`IrsInfos`] exposes the information reported for a particular solve.
864///
865/// The returned value describes the solving process.
866/// `Ok` indicates that the solve finished successfully. An error indicates that one of the arguments is incorrect, that the parameter or info structures are misconfigured, or that the solve did not finish successfully.
867/// Check `niters` and `dinfo` for additional error details.
868/// Provide the required device workspace through `workspace`.
869/// Query the required byte count with [`xgels_buffer_size`].
870/// Apply any required configuration through the parameter structure before calling [`xgels_buffer_size`] so the workspace size matches that configuration.
871///
872/// The following table provides all possible lowest-precision values corresponding to the input/output data type.
873/// If the lowest precision matches the input/output data type, the main
874/// precision factorization is used.
875///
876/// Tensor Float (TF32), introduced with NVIDIA Ampere architecture GPUs, is the most robust tensor core accelerated compute mode for the iterative refinement solver.
877/// It solves a broad range of HPC problems and can provide up to 4x and 5x
878/// speedups for real and complex systems, respectively.
879/// On Volta and Turing architecture GPUs, half precision tensor core acceleration is recommended.
880/// In cases where the iterative refinement solver fails to converge to the desired accuracy (main precision, input/output data precision), it is recommended to use main precision as internal lowest precision.
881///
882/// **Supported input/output data type and lower precision for the IRS solver**
883///
884/// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
885/// | --- | --- |
886/// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
887/// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
888/// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
889/// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
890///
891/// [`xgels_buffer_size`] returns the required workspace size in bytes for the
892/// current [`IrsParams`] configuration.
893///
894/// # Errors
895///
896/// Returns an error if cuSOLVER rejects the matrix dimensions, leading
897/// dimensions, parameter structure, info structure, or workspace. The workspace
898/// can become invalid if [`xgels_buffer_size`] is called and then an IRS
899/// configuration value, such as the lowest precision, is changed. cuSOLVER can
900/// also report an error if host memory allocation fails, if the selected IRS
901/// configuration is not supported on the current GPU architecture, if the
902/// library has not been initialized, or if the solve ends with an internal or
903/// numerical failure. Check `niters` and `dinfo` for additional solver details.
904pub fn xgels<T: DataTypeLike>(
905 ctx: &Context,
906 params: &mut IrsParams,
907 infos: &IrsInfos,
908 m: usize,
909 n: usize,
910 nrhs: usize,
911 a: MatrixMut<'_, T>,
912 b: MatrixRef<'_, T>,
913 x: MatrixMut<'_, T>,
914 device_workspace: &mut DeviceMemory<u8>,
915 dev_info: &mut DeviceMemory<i32>,
916) -> Result<i32> {
917 ctx.bind()?;
918 if n > m {
919 return Err(Error::InvalidMatrixShape);
920 }
921 validate_matrix(m, n, a.data.len(), a.leading_dimension)?;
922 validate_matrix(m, nrhs, b.data.len(), b.leading_dimension)?;
923 validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
924 require_info_buffer(dev_info)?;
925 let workspace_bytes = xgels_buffer_size::<T>(ctx, params, m, n, nrhs)?;
926 require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
927 let mut niters = 0;
928 unsafe {
929 try_ffi!(sys::cusolverDnIRSXgels(
930 ctx.as_raw(),
931 params.as_raw(),
932 infos.as_raw(),
933 to_i32(m, "m")?,
934 to_i32(n, "n")?,
935 to_i32(nrhs, "nrhs")?,
936 a.data.as_mut_ptr() as _,
937 to_i32(a.leading_dimension, "ldda")?,
938 b.data.as_ptr() as _,
939 to_i32(b.leading_dimension, "lddb")?,
940 x.data.as_mut_ptr() as _,
941 to_i32(x.leading_dimension, "lddx")?,
942 device_workspace.as_mut_ptr() as _,
943 to_u64(workspace_bytes, "lwork_bytes")?,
944 &raw mut niters,
945 dev_info.as_mut_ptr() as _,
946 ))?;
947 }
948 Ok(niters)
949}
950
951fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
952 if dev_info.is_empty() {
953 return Err(Error::InvalidVectorShape);
954 }
955 Ok(())
956}
957
958fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
959 if actual < required {
960 return Err(Error::InsufficientWorkspaceSize { required, actual });
961 }
962 Ok(())
963}
964
965unsafe fn copy_residual_history<T: Copy>(
966 history: *const T,
967 leading_dimension: usize,
968 valid_rows: usize,
969) -> ResidualHistory<T> {
970 let history = unsafe { slice::from_raw_parts(history, leading_dimension.saturating_mul(2)) };
971 let mut rows = Vec::with_capacity(valid_rows);
972 for row in 0..valid_rows {
973 rows.push(ResidualHistoryEntry {
974 total_iterations: history[row],
975 residual_norm: history[row + leading_dimension],
976 });
977 }
978 ResidualHistory {
979 rows,
980 leading_dimension,
981 }
982}
983
984fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
985 if rows == 0 || cols == 0 {
986 return Err(Error::InvalidMatrixShape);
987 }
988 if lda < rows {
989 return Err(Error::InvalidLeadingDimension);
990 }
991 let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
992 if len < required {
993 return Err(Error::InvalidMatrixShape);
994 }
995 Ok(())
996}
997
#[cfg(all(test, feature = "testing"))]
mod tests {
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::testing::setup_context_if_available;

    /// Asserts element-wise agreement within `tolerance`. The IRS solvers may
    /// refine in a lower precision, so exact equality is not guaranteed in
    /// general.
    fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
        assert_eq!(actual.len(), expected.len());
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= tolerance,
                "element {}: got {}, expected {}",
                index,
                got,
                want
            );
        }
    }

    #[test]
    fn test_xgesv_solves_diagonal_system() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let mut params = IrsParams::create()?;
        let infos = IrsInfos::create()?;

        let mut a = DeviceMemory::from_slice(&[
            2.0_f32, 0.0, //
            0.0, 4.0,
        ])?;
        let b = DeviceMemory::from_slice(&[
            6.0_f32, //
            8.0,
        ])?;
        let mut x = DeviceMemory::create(2)?;
        let workspace_bytes = xgesv_buffer_size::<f32>(&ctx, &mut params, 2, 1)?;
        let mut workspace = DeviceMemory::create(workspace_bytes.max(1))?;
        let mut dev_info = DeviceMemory::create(1)?;

        let _ = xgesv(
            &ctx,
            &mut params,
            &infos,
            2,
            1,
            MatrixMut::new(&mut a, 2),
            MatrixRef::new(&b, 2),
            MatrixMut::new(&mut x, 2),
            &mut workspace,
            &mut dev_info,
        )?;

        assert_eq!(dev_info.copy_to_host_vec()?, vec![0]);
        assert_eq!(x.copy_to_host_vec()?, vec![3.0, 2.0]);
        Ok(())
    }

    #[test]
    fn test_xgels_solves_overdetermined_system() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let mut params = IrsParams::create()?;
        let infos = IrsInfos::create()?;

        // Column-major 3x2 matrix [[1, 0], [0, 1], [0, 0]]. The third equation
        // reads 0 = 0, so the least-squares solution is exactly [2, 3].
        let mut a = DeviceMemory::from_slice(&[
            1.0_f32, 0.0, 0.0, //
            0.0, 1.0, 0.0,
        ])?;
        let b = DeviceMemory::from_slice(&[2.0_f32, 3.0, 0.0])?;
        let mut x = DeviceMemory::create(2)?;
        let workspace_bytes = xgels_buffer_size::<f32>(&ctx, &mut params, 3, 2, 1)?;
        let mut workspace = DeviceMemory::create(workspace_bytes.max(1))?;
        let mut dev_info = DeviceMemory::create(1)?;

        let _ = xgels(
            &ctx,
            &mut params,
            &infos,
            3,
            2,
            1,
            MatrixMut::new(&mut a, 3),
            MatrixRef::new(&b, 3),
            MatrixMut::new(&mut x, 2),
            &mut workspace,
            &mut dev_info,
        )?;

        assert_eq!(dev_info.copy_to_host_vec()?, vec![0]);
        assert_close(&x.copy_to_host_vec()?, &[2.0, 3.0], 1e-5);
        Ok(())
    }

    #[test]
    fn test_xgels_rejects_more_columns_than_rows() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let mut params = IrsParams::create()?;
        let infos = IrsInfos::create()?;

        // 1x2 matrix: n > m is rejected before any buffer or FFI work.
        let mut a = DeviceMemory::from_slice(&[1.0_f32, 2.0])?;
        let b = DeviceMemory::from_slice(&[1.0_f32])?;
        let mut x = DeviceMemory::create(2)?;
        let mut workspace = DeviceMemory::create(1)?;
        let mut dev_info = DeviceMemory::create(1)?;

        let result = xgels(
            &ctx,
            &mut params,
            &infos,
            1,
            2,
            1,
            MatrixMut::new(&mut a, 1),
            MatrixRef::new(&b, 1),
            MatrixMut::new(&mut x, 2),
            &mut workspace,
            &mut dev_info,
        );

        assert!(matches!(result, Err(Error::InvalidMatrixShape)));
        Ok(())
    }
}