1#[allow(unused_imports)]
2use crate::{eigen::xsyevd, error::Status};
3
4use std::ptr;
5
6use singe_cuda::{
7 data_type::{DataType, DataTypeLike},
8 memory::DeviceMemory,
9 types::{Complex32, Complex64},
10};
11
12use crate::{
13 context::Context,
14 error::{Error, Result},
15 layout::{
16 ByteWorkspaceMut, MatrixMut, MatrixRef, StridedBatchedMatrixMut, StridedBatchedMatrixRef,
17 StridedBatchedVectorMut, StridedBatchedVectorRef, WorkspaceSizes,
18 },
19 params::Params,
20 sys, try_ffi,
21 types::{EigenMode, SvdMode, TruncatedSvdMode},
22 utility::{to_i32, to_i64, to_usize},
23};
24
25#[derive(Debug)]
26pub struct GesvdjInfo {
27 handle: sys::gesvdjInfo_t,
28}
29
30unsafe impl Send for GesvdjInfo {}
33unsafe impl Sync for GesvdjInfo {}
34
35impl GesvdjInfo {
36 pub fn create() -> Result<Self> {
43 let mut handle = ptr::null_mut();
44 unsafe {
45 try_ffi!(sys::cusolverDnCreateGesvdjInfo(&raw mut handle))?;
46 }
47
48 if handle.is_null() {
49 return Err(Error::NullHandle);
50 }
51
52 Ok(Self { handle })
53 }
54
55 pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
61 unsafe {
62 try_ffi!(sys::cusolverDnXgesvdjSetTolerance(self.as_raw(), tolerance,))?;
63 }
64 Ok(())
65 }
66
67 pub fn set_max_sweeps(&mut self, max_sweeps: i32) -> Result<()> {
74 unsafe {
75 try_ffi!(sys::cusolverDnXgesvdjSetMaxSweeps(
76 self.as_raw(),
77 max_sweeps,
78 ))?;
79 }
80 Ok(())
81 }
82
83 pub fn set_sort_eigenvalues(&mut self, sort_eigenvalues: bool) -> Result<()> {
92 unsafe {
93 try_ffi!(sys::cusolverDnXgesvdjSetSortEig(
94 self.as_raw(),
95 i32::from(sort_eigenvalues),
96 ))?;
97 }
98 Ok(())
99 }
100
101 pub fn residual(&self, ctx: &Context) -> Result<f64> {
112 ctx.bind()?;
113
114 let mut residual = 0.0;
115 unsafe {
116 try_ffi!(sys::cusolverDnXgesvdjGetResidual(
117 ctx.as_raw(),
118 self.as_raw(),
119 &raw mut residual,
120 ))?;
121 }
122 Ok(residual)
123 }
124
125 pub fn executed_sweeps(&self, ctx: &Context) -> Result<i32> {
134 ctx.bind()?;
135
136 let mut sweeps = 0;
137 unsafe {
138 try_ffi!(sys::cusolverDnXgesvdjGetSweeps(
139 ctx.as_raw(),
140 self.as_raw(),
141 &raw mut sweeps,
142 ))?;
143 }
144 Ok(sweeps)
145 }
146
147 pub fn as_raw(&self) -> sys::gesvdjInfo_t {
148 self.handle
149 }
150}
151
152impl Drop for GesvdjInfo {
153 fn drop(&mut self) {
154 unsafe {
155 if let Err(err) = try_ffi!(sys::cusolverDnDestroyGesvdjInfo(self.handle)) {
156 #[cfg(debug_assertions)]
157 eprintln!("failed to destroy cusolver gesvdj info: {err}");
158 }
159 }
160 }
161}
162
163#[derive(Debug, Clone, Copy, PartialEq, Eq)]
164pub struct Gesvd {
165 pub job_u: SvdMode,
166 pub job_vt: SvdMode,
167 pub rows: usize,
168 pub columns: usize,
169}
170
171impl Gesvd {
172 pub fn new(job_u: SvdMode, job_vt: SvdMode, rows: usize, columns: usize) -> Self {
173 Self {
174 job_u,
175 job_vt,
176 rows,
177 columns,
178 }
179 }
180
181 pub fn workspace_size<
182 TA: DataTypeLike,
183 TS: DataTypeLike,
184 TU: DataTypeLike,
185 TVT: DataTypeLike,
186 >(
187 self,
188 ctx: &Context,
189 params: &Params,
190 input: GesvdInput<'_, TA, TS, TU, TVT>,
191 ) -> Result<WorkspaceSizes> {
192 xgesvd_buffer_size(
193 ctx,
194 params,
195 self.job_u,
196 self.job_vt,
197 self.rows,
198 self.columns,
199 input.a,
200 input.singular_values,
201 input.left_vectors,
202 input.right_vectors_transposed,
203 )
204 }
205
206 pub fn execute<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
207 self,
208 ctx: &Context,
209 params: &Params,
210 bindings: GesvdBindings<'_, TA, TS, TU, TVT>,
211 ) -> Result<()> {
212 xgesvd(
213 ctx,
214 params,
215 self.job_u,
216 self.job_vt,
217 self.rows,
218 self.columns,
219 bindings.a,
220 bindings.singular_values,
221 bindings.left_vectors,
222 bindings.right_vectors_transposed,
223 bindings.workspace,
224 bindings.dev_info,
225 )
226 }
227}
228
229#[derive(Debug, Clone, Copy)]
230pub struct GesvdInput<'a, TA, TS, TU, TVT> {
231 pub a: MatrixRef<'a, TA>,
232 pub singular_values: &'a DeviceMemory<TS>,
233 pub left_vectors: Option<MatrixRef<'a, TU>>,
234 pub right_vectors_transposed: Option<MatrixRef<'a, TVT>>,
235}
236
237#[derive(Debug)]
238pub struct GesvdBindings<'a, TA, TS, TU, TVT> {
239 pub a: MatrixMut<'a, TA>,
240 pub singular_values: &'a mut DeviceMemory<TS>,
241 pub left_vectors: Option<MatrixMut<'a, TU>>,
242 pub right_vectors_transposed: Option<MatrixMut<'a, TVT>>,
243 pub workspace: ByteWorkspaceMut<'a>,
244 pub dev_info: &'a mut DeviceMemory<i32>,
245}
246
247#[derive(Debug, Clone, Copy, PartialEq, Eq)]
248pub struct Gesvdj {
249 pub mode: EigenMode,
250 pub economy: bool,
251 pub rows: usize,
252 pub columns: usize,
253}
254
255impl Gesvdj {
256 pub fn new(mode: EigenMode, economy: bool, rows: usize, columns: usize) -> Self {
257 Self {
258 mode,
259 economy,
260 rows,
261 columns,
262 }
263 }
264
265 pub fn workspace_size_f32(
266 self,
267 ctx: &Context,
268 input: GesvdjInput<'_, f32, f32>,
269 params: &GesvdjInfo,
270 ) -> Result<usize> {
271 sgesvdj_buffer_size(
272 ctx,
273 self.mode,
274 self.economy,
275 self.rows,
276 self.columns,
277 input.a,
278 input.singular_values,
279 input.left_vectors,
280 input.right_vectors,
281 params,
282 )
283 }
284
285 pub fn execute_f32(
286 self,
287 ctx: &Context,
288 bindings: GesvdjBindings<'_, f32, f32>,
289 params: &GesvdjInfo,
290 ) -> Result<()> {
291 sgesvdj(
292 ctx,
293 self.mode,
294 self.economy,
295 self.rows,
296 self.columns,
297 bindings.a,
298 bindings.singular_values,
299 bindings.left_vectors,
300 bindings.right_vectors,
301 bindings.workspace,
302 bindings.dev_info,
303 params,
304 )
305 }
306
307 pub fn workspace_size_f64(
308 self,
309 ctx: &Context,
310 input: GesvdjInput<'_, f64, f64>,
311 params: &GesvdjInfo,
312 ) -> Result<usize> {
313 dgesvdj_buffer_size(
314 ctx,
315 self.mode,
316 self.economy,
317 self.rows,
318 self.columns,
319 input.a,
320 input.singular_values,
321 input.left_vectors,
322 input.right_vectors,
323 params,
324 )
325 }
326
327 pub fn execute_f64(
328 self,
329 ctx: &Context,
330 bindings: GesvdjBindings<'_, f64, f64>,
331 params: &GesvdjInfo,
332 ) -> Result<()> {
333 dgesvdj(
334 ctx,
335 self.mode,
336 self.economy,
337 self.rows,
338 self.columns,
339 bindings.a,
340 bindings.singular_values,
341 bindings.left_vectors,
342 bindings.right_vectors,
343 bindings.workspace,
344 bindings.dev_info,
345 params,
346 )
347 }
348
349 pub fn workspace_size_complex_f32(
350 self,
351 ctx: &Context,
352 input: GesvdjInput<'_, Complex32, f32>,
353 params: &GesvdjInfo,
354 ) -> Result<usize> {
355 cgesvdj_buffer_size(
356 ctx,
357 self.mode,
358 self.economy,
359 self.rows,
360 self.columns,
361 input.a,
362 input.singular_values,
363 input.left_vectors,
364 input.right_vectors,
365 params,
366 )
367 }
368
369 pub fn execute_complex_f32(
370 self,
371 ctx: &Context,
372 bindings: GesvdjBindings<'_, Complex32, f32>,
373 params: &GesvdjInfo,
374 ) -> Result<()> {
375 cgesvdj(
376 ctx,
377 self.mode,
378 self.economy,
379 self.rows,
380 self.columns,
381 bindings.a,
382 bindings.singular_values,
383 bindings.left_vectors,
384 bindings.right_vectors,
385 bindings.workspace,
386 bindings.dev_info,
387 params,
388 )
389 }
390
391 pub fn workspace_size_complex_f64(
392 self,
393 ctx: &Context,
394 input: GesvdjInput<'_, Complex64, f64>,
395 params: &GesvdjInfo,
396 ) -> Result<usize> {
397 zgesvdj_buffer_size(
398 ctx,
399 self.mode,
400 self.economy,
401 self.rows,
402 self.columns,
403 input.a,
404 input.singular_values,
405 input.left_vectors,
406 input.right_vectors,
407 params,
408 )
409 }
410
411 pub fn execute_complex_f64(
412 self,
413 ctx: &Context,
414 bindings: GesvdjBindings<'_, Complex64, f64>,
415 params: &GesvdjInfo,
416 ) -> Result<()> {
417 zgesvdj(
418 ctx,
419 self.mode,
420 self.economy,
421 self.rows,
422 self.columns,
423 bindings.a,
424 bindings.singular_values,
425 bindings.left_vectors,
426 bindings.right_vectors,
427 bindings.workspace,
428 bindings.dev_info,
429 params,
430 )
431 }
432}
433
434#[derive(Debug, Clone, Copy)]
435pub struct GesvdjInput<'a, TA, TS> {
436 pub a: MatrixRef<'a, TA>,
437 pub singular_values: &'a DeviceMemory<TS>,
438 pub left_vectors: Option<MatrixRef<'a, TA>>,
439 pub right_vectors: Option<MatrixRef<'a, TA>>,
440}
441
442#[derive(Debug)]
443pub struct GesvdjBindings<'a, TA, TS> {
444 pub a: MatrixMut<'a, TA>,
445 pub singular_values: &'a mut DeviceMemory<TS>,
446 pub left_vectors: Option<MatrixMut<'a, TA>>,
447 pub right_vectors: Option<MatrixMut<'a, TA>>,
448 pub workspace: &'a mut DeviceMemory<TA>,
449 pub dev_info: &'a mut DeviceMemory<i32>,
450}
451
452pub fn xgesvd_buffer_size<
453 TA: DataTypeLike,
454 TS: DataTypeLike,
455 TU: DataTypeLike,
456 TVT: DataTypeLike,
457>(
458 ctx: &Context,
459 params: &Params,
460 job_u: SvdMode,
461 job_vt: SvdMode,
462 m: usize,
463 n: usize,
464 a: MatrixRef<'_, TA>,
465 s: &DeviceMemory<TS>,
466 u: Option<MatrixRef<'_, TU>>,
467 vt: Option<MatrixRef<'_, TVT>>,
468) -> Result<WorkspaceSizes> {
469 let a_type = TA::data_type();
470 let s_type = TS::data_type();
471 let u_type = TU::data_type();
472 let vt_type = TVT::data_type();
473 ctx.bind()?;
474 validate_gesvd_dims(m, n)?;
475 validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
476 validate_x_vector(m.min(n), s.byte_len(), s_type)?;
477 validate_x_svd_output(m, m, matrix_ref_parts(u), job_u, u_type)?;
478 validate_x_svd_output(n, n, matrix_ref_parts(vt), job_vt, vt_type)?;
479 if matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite) {
480 return Err(Error::InvalidSvdMode);
481 }
482
483 let (u_ptr, ldu) = optional_x_matrix_ptr(matrix_ref_parts(u), m, m, job_u, u_type)?;
484 let (vt_ptr, ldvt) = optional_x_matrix_ptr(matrix_ref_parts(vt), n, n, job_vt, vt_type)?;
485 let mut device_bytes = 0;
486 let mut host_bytes = 0;
487 unsafe {
488 try_ffi!(sys::cusolverDnXgesvd_bufferSize(
489 ctx.as_raw(),
490 params.as_raw(),
491 job_u.as_raw(),
492 job_vt.as_raw(),
493 to_i64(m, "m")?,
494 to_i64(n, "n")?,
495 a_type.into(),
496 a.data.as_ptr().cast(),
497 to_i64(a.leading_dimension, "lda")?,
498 s_type.into(),
499 s.as_ptr().cast(),
500 u_type.into(),
501 u_ptr.cast(),
502 ldu,
503 vt_type.into(),
504 vt_ptr.cast(),
505 ldvt,
506 a_type.into(),
507 &raw mut device_bytes,
508 &raw mut host_bytes,
509 ))?;
510 }
511 Ok(WorkspaceSizes::new(
512 device_bytes as usize,
513 host_bytes as usize,
514 ))
515}
516
517pub fn xgesvd<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
571 ctx: &Context,
572 params: &Params,
573 job_u: SvdMode,
574 job_vt: SvdMode,
575 m: usize,
576 n: usize,
577 a: MatrixMut<'_, TA>,
578 s: &mut DeviceMemory<TS>,
579 u: Option<MatrixMut<'_, TU>>,
580 vt: Option<MatrixMut<'_, TVT>>,
581 workspace: ByteWorkspaceMut<'_>,
582 dev_info: &mut DeviceMemory<i32>,
583) -> Result<()> {
584 let a_type = TA::data_type();
585 let s_type = TS::data_type();
586 let u_type = TU::data_type();
587 let vt_type = TVT::data_type();
588 ctx.bind()?;
589 validate_gesvd_dims(m, n)?;
590 validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
591 validate_x_vector(m.min(n), s.byte_len(), s_type)?;
592 validate_x_svd_output(m, m, matrix_mut_ref_parts(u.as_ref()), job_u, u_type)?;
593 validate_x_svd_output(n, n, matrix_mut_ref_parts(vt.as_ref()), job_vt, vt_type)?;
594 if matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite) {
595 return Err(Error::InvalidSvdMode);
596 }
597 require_info_buffer(dev_info)?;
598
599 let workspace_sizes = xgesvd_buffer_size(
600 ctx,
601 params,
602 job_u,
603 job_vt,
604 m,
605 n,
606 a.as_ref(),
607 s,
608 matrix_mut_ref_option(u.as_ref()),
609 matrix_mut_ref_option(vt.as_ref()),
610 )?;
611 require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
612 require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
613
614 let (u_ptr, ldu) = optional_x_matrix_mut_ptr(matrix_mut_parts(u), m, m, job_u, u_type)?;
615 let (vt_ptr, ldvt) = optional_x_matrix_mut_ptr(matrix_mut_parts(vt), n, n, job_vt, vt_type)?;
616 unsafe {
617 try_ffi!(sys::cusolverDnXgesvd(
618 ctx.as_raw(),
619 params.as_raw(),
620 job_u.as_raw(),
621 job_vt.as_raw(),
622 to_i64(m, "m")?,
623 to_i64(n, "n")?,
624 a_type.into(),
625 a.data.as_mut_ptr().cast(),
626 to_i64(a.leading_dimension, "lda")?,
627 s_type.into(),
628 s.as_mut_ptr().cast(),
629 u_type.into(),
630 u_ptr.cast(),
631 ldu,
632 vt_type.into(),
633 vt_ptr.cast(),
634 ldvt,
635 a_type.into(),
636 workspace.device.as_mut_ptr().cast(),
637 workspace_sizes.device_bytes as _,
638 workspace.host.as_mut_ptr().cast(),
639 workspace_sizes.host_bytes as _,
640 dev_info.as_mut_ptr().cast(),
641 ))?;
642 }
643 Ok(())
644}
645
646pub fn xgesvdp_buffer_size<
647 TA: DataTypeLike,
648 TS: DataTypeLike,
649 TU: DataTypeLike,
650 TV: DataTypeLike,
651>(
652 ctx: &Context,
653 params: &Params,
654 jobz: EigenMode,
655 econ: bool,
656 m: usize,
657 n: usize,
658 a: MatrixRef<'_, TA>,
659 s: &DeviceMemory<TS>,
660 u: Option<MatrixRef<'_, TU>>,
661 v: Option<MatrixRef<'_, TV>>,
662) -> Result<WorkspaceSizes> {
663 let a_type = TA::data_type();
664 let s_type = TS::data_type();
665 let u_type = TU::data_type();
666 let v_type = TV::data_type();
667 ctx.bind()?;
668 validate_xgesvdp_inputs(
669 m,
670 n,
671 a.data.byte_len(),
672 a.leading_dimension,
673 a_type,
674 s.byte_len(),
675 s_type,
676 jobz,
677 econ,
678 matrix_ref_parts(u).as_ref(),
679 u_type,
680 matrix_ref_parts(v).as_ref(),
681 v_type,
682 )?;
683 let (u_ptr, ldu) = optional_x_eig_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ, u_type)?;
684 let (v_ptr, ldv) = optional_x_eig_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ, v_type)?;
685 let mut device_bytes = 0;
686 let mut host_bytes = 0;
687 unsafe {
688 try_ffi!(sys::cusolverDnXgesvdp_bufferSize(
689 ctx.as_raw(),
690 params.as_raw(),
691 jobz.into(),
692 i32::from(econ),
693 to_i64(m, "m")?,
694 to_i64(n, "n")?,
695 a_type.into(),
696 a.data.as_ptr().cast(),
697 to_i64(a.leading_dimension, "lda")?,
698 s_type.into(),
699 s.as_ptr().cast(),
700 u_type.into(),
701 u_ptr.cast(),
702 ldu,
703 v_type.into(),
704 v_ptr.cast(),
705 ldv,
706 a_type.into(),
707 &raw mut device_bytes,
708 &raw mut host_bytes,
709 ))?;
710 }
711 Ok(WorkspaceSizes::new(
712 device_bytes as usize,
713 host_bytes as usize,
714 ))
715}
716
717pub fn xgesvdp<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
780 ctx: &Context,
781 params: &Params,
782 jobz: EigenMode,
783 econ: bool,
784 m: usize,
785 n: usize,
786 a: MatrixMut<'_, TA>,
787 s: &mut DeviceMemory<TS>,
788 u: Option<MatrixMut<'_, TU>>,
789 v: Option<MatrixMut<'_, TV>>,
790 workspace: ByteWorkspaceMut<'_>,
791 dev_info: &mut DeviceMemory<i32>,
792 err_sigma: Option<&mut f64>,
793) -> Result<()> {
794 let a_type = TA::data_type();
795 let s_type = TS::data_type();
796 let u_type = TU::data_type();
797 let v_type = TV::data_type();
798 ctx.bind()?;
799 validate_xgesvdp_inputs(
800 m,
801 n,
802 a.data.byte_len(),
803 a.leading_dimension,
804 a_type,
805 s.byte_len(),
806 s_type,
807 jobz,
808 econ,
809 matrix_mut_ref_parts(u.as_ref()).as_ref(),
810 u_type,
811 matrix_mut_ref_parts(v.as_ref()).as_ref(),
812 v_type,
813 )?;
814 require_info_buffer(dev_info)?;
815 let workspace_sizes = xgesvdp_buffer_size(
816 ctx,
817 params,
818 jobz,
819 econ,
820 m,
821 n,
822 a.as_ref(),
823 s,
824 matrix_mut_ref_option(u.as_ref()),
825 matrix_mut_ref_option(v.as_ref()),
826 )?;
827 require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
828 require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
829
830 let (u_ptr, ldu) =
831 optional_x_eig_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ, u_type)?;
832 let (v_ptr, ldv) =
833 optional_x_eig_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ, v_type)?;
834 unsafe {
835 try_ffi!(sys::cusolverDnXgesvdp(
836 ctx.as_raw(),
837 params.as_raw(),
838 jobz.into(),
839 i32::from(econ),
840 to_i64(m, "m")?,
841 to_i64(n, "n")?,
842 a_type.into(),
843 a.data.as_mut_ptr().cast(),
844 to_i64(a.leading_dimension, "lda")?,
845 s_type.into(),
846 s.as_mut_ptr().cast(),
847 u_type.into(),
848 u_ptr.cast(),
849 ldu,
850 v_type.into(),
851 v_ptr.cast(),
852 ldv,
853 a_type.into(),
854 workspace.device.as_mut_ptr().cast(),
855 workspace_sizes.device_bytes as _,
856 workspace.host.as_mut_ptr().cast(),
857 workspace_sizes.host_bytes as _,
858 dev_info.as_mut_ptr().cast(),
859 err_sigma.map_or(ptr::null_mut(), |value| value as *mut f64),
860 ))?;
861 }
862 Ok(())
863}
864
865pub fn xgesvdr_buffer_size<
866 TA: DataTypeLike,
867 TS: DataTypeLike,
868 TU: DataTypeLike,
869 TV: DataTypeLike,
870>(
871 ctx: &Context,
872 params: &Params,
873 job_u: TruncatedSvdMode,
874 job_v: TruncatedSvdMode,
875 m: usize,
876 n: usize,
877 k: usize,
878 p: usize,
879 niters: usize,
880 a: MatrixRef<'_, TA>,
881 s: &DeviceMemory<TS>,
882 u: Option<MatrixRef<'_, TU>>,
883 v: Option<MatrixRef<'_, TV>>,
884) -> Result<WorkspaceSizes> {
885 let a_type = TA::data_type();
886 let s_type = TS::data_type();
887 let u_type = TU::data_type();
888 let v_type = TV::data_type();
889 ctx.bind()?;
890 validate_xgesvdr_inputs(
891 m,
892 n,
893 k,
894 p,
895 niters,
896 a.data.byte_len(),
897 a.leading_dimension,
898 a_type,
899 s.byte_len(),
900 s_type,
901 job_u,
902 matrix_ref_parts(u).as_ref(),
903 u_type,
904 job_v,
905 matrix_ref_parts(v).as_ref(),
906 v_type,
907 )?;
908 let (u_ptr, ldu) = optional_x_truncated_u_ptr(matrix_ref_parts(u), m, k, job_u, u_type)?;
909 let (v_ptr, ldv) = optional_x_truncated_v_ptr(matrix_ref_parts(v), n, k, job_v, v_type)?;
910 let mut device_bytes = 0;
911 let mut host_bytes = 0;
912 unsafe {
913 try_ffi!(sys::cusolverDnXgesvdr_bufferSize(
914 ctx.as_raw(),
915 params.as_raw(),
916 job_u.as_raw(),
917 job_v.as_raw(),
918 to_i64(m, "m")?,
919 to_i64(n, "n")?,
920 to_i64(k, "k")?,
921 to_i64(p, "p")?,
922 to_i64(niters, "niters")?,
923 a_type.into(),
924 a.data.as_ptr().cast(),
925 to_i64(a.leading_dimension, "lda")?,
926 s_type.into(),
927 s.as_ptr().cast(),
928 u_type.into(),
929 u_ptr.cast(),
930 ldu,
931 v_type.into(),
932 v_ptr.cast(),
933 ldv,
934 a_type.into(),
935 &raw mut device_bytes,
936 &raw mut host_bytes,
937 ))?;
938 }
939 Ok(WorkspaceSizes::new(
940 device_bytes as usize,
941 host_bytes as usize,
942 ))
943}
944
945pub fn xgesvdr<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
1012 ctx: &Context,
1013 params: &Params,
1014 job_u: TruncatedSvdMode,
1015 job_v: TruncatedSvdMode,
1016 m: usize,
1017 n: usize,
1018 k: usize,
1019 p: usize,
1020 niters: usize,
1021 a: MatrixMut<'_, TA>,
1022 s: &mut DeviceMemory<TS>,
1023 u: Option<MatrixMut<'_, TU>>,
1024 v: Option<MatrixMut<'_, TV>>,
1025 workspace: ByteWorkspaceMut<'_>,
1026 dev_info: &mut DeviceMemory<i32>,
1027) -> Result<()> {
1028 let a_type = TA::data_type();
1029 let s_type = TS::data_type();
1030 let u_type = TU::data_type();
1031 let v_type = TV::data_type();
1032 ctx.bind()?;
1033 validate_xgesvdr_inputs(
1034 m,
1035 n,
1036 k,
1037 p,
1038 niters,
1039 a.data.byte_len(),
1040 a.leading_dimension,
1041 a_type,
1042 s.byte_len(),
1043 s_type,
1044 job_u,
1045 matrix_mut_ref_parts(u.as_ref()).as_ref(),
1046 u_type,
1047 job_v,
1048 matrix_mut_ref_parts(v.as_ref()).as_ref(),
1049 v_type,
1050 )?;
1051 require_info_buffer(dev_info)?;
1052 let workspace_sizes = xgesvdr_buffer_size(
1053 ctx,
1054 params,
1055 job_u,
1056 job_v,
1057 m,
1058 n,
1059 k,
1060 p,
1061 niters,
1062 a.as_ref(),
1063 s,
1064 matrix_mut_ref_option(u.as_ref()),
1065 matrix_mut_ref_option(v.as_ref()),
1066 )?;
1067 require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
1068 require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
1069 let (u_ptr, ldu) = optional_x_truncated_u_mut_ptr(matrix_mut_parts(u), m, k, job_u, u_type)?;
1070 let (v_ptr, ldv) = optional_x_truncated_v_mut_ptr(matrix_mut_parts(v), n, k, job_v, v_type)?;
1071 unsafe {
1072 try_ffi!(sys::cusolverDnXgesvdr(
1073 ctx.as_raw(),
1074 params.as_raw(),
1075 job_u.as_raw(),
1076 job_v.as_raw(),
1077 to_i64(m, "m")?,
1078 to_i64(n, "n")?,
1079 to_i64(k, "k")?,
1080 to_i64(p, "p")?,
1081 to_i64(niters, "niters")?,
1082 a_type.into(),
1083 a.data.as_mut_ptr().cast(),
1084 to_i64(a.leading_dimension, "lda")?,
1085 s_type.into(),
1086 s.as_mut_ptr().cast(),
1087 u_type.into(),
1088 u_ptr.cast(),
1089 ldu,
1090 v_type.into(),
1091 v_ptr.cast(),
1092 ldv,
1093 a_type.into(),
1094 workspace.device.as_mut_ptr().cast(),
1095 workspace_sizes.device_bytes as _,
1096 workspace.host.as_mut_ptr().cast(),
1097 workspace_sizes.host_bytes as _,
1098 dev_info.as_mut_ptr().cast(),
1099 ))?;
1100 }
1101 Ok(())
1102}
1103
1104pub fn sgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1105 ctx.bind()?;
1106 validate_gesvd_dims(m, n)?;
1107 let mut lwork = 0;
1108 unsafe {
1109 try_ffi!(sys::cusolverDnSgesvd_bufferSize(
1110 ctx.as_raw(),
1111 to_i32(m, "m")?,
1112 to_i32(n, "n")?,
1113 &raw mut lwork,
1114 ))?;
1115 }
1116 to_usize(lwork, "lwork")
1117}
1118
1119pub fn dgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1120 ctx.bind()?;
1121 validate_gesvd_dims(m, n)?;
1122 let mut lwork = 0;
1123 unsafe {
1124 try_ffi!(sys::cusolverDnDgesvd_bufferSize(
1125 ctx.as_raw(),
1126 to_i32(m, "m")?,
1127 to_i32(n, "n")?,
1128 &raw mut lwork,
1129 ))?;
1130 }
1131 to_usize(lwork, "lwork")
1132}
1133
1134pub fn cgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1135 ctx.bind()?;
1136 validate_gesvd_dims(m, n)?;
1137 let mut lwork = 0;
1138 unsafe {
1139 try_ffi!(sys::cusolverDnCgesvd_bufferSize(
1140 ctx.as_raw(),
1141 to_i32(m, "m")?,
1142 to_i32(n, "n")?,
1143 &raw mut lwork,
1144 ))?;
1145 }
1146 to_usize(lwork, "lwork")
1147}
1148
1149pub fn zgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1150 ctx.bind()?;
1151 validate_gesvd_dims(m, n)?;
1152 let mut lwork = 0;
1153 unsafe {
1154 try_ffi!(sys::cusolverDnZgesvd_bufferSize(
1155 ctx.as_raw(),
1156 to_i32(m, "m")?,
1157 to_i32(n, "n")?,
1158 &raw mut lwork,
1159 ))?;
1160 }
1161 to_usize(lwork, "lwork")
1162}
1163
1164pub fn sgesvd(
1199 ctx: &Context,
1200 job_u: SvdMode,
1201 job_vt: SvdMode,
1202 m: usize,
1203 n: usize,
1204 a: MatrixMut<'_, f32>,
1205 s: &mut DeviceMemory<f32>,
1206 u: Option<MatrixMut<'_, f32>>,
1207 vt: Option<MatrixMut<'_, f32>>,
1208 workspace: &mut DeviceMemory<f32>,
1209 rwork: Option<&mut DeviceMemory<f32>>,
1210 dev_info: &mut DeviceMemory<i32>,
1211) -> Result<()> {
1212 ctx.bind()?;
1213 validate_gesvd_inputs(
1214 m,
1215 n,
1216 a.data.len(),
1217 a.leading_dimension,
1218 s.len(),
1219 job_u,
1220 matrix_mut_ref_parts(u.as_ref()).as_ref(),
1221 job_vt,
1222 matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1223 )?;
1224 require_info_buffer(dev_info)?;
1225 require_rwork_buffer(rwork.as_deref(), m, n)?;
1226 let lwork = sgesvd_buffer_size(ctx, m, n)?;
1227 require_workspace(workspace.len(), lwork)?;
1228 let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1229 let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1230 unsafe {
1231 try_ffi!(sys::cusolverDnSgesvd(
1232 ctx.as_raw(),
1233 job_u.as_raw(),
1234 job_vt.as_raw(),
1235 to_i32(m, "m")?,
1236 to_i32(n, "n")?,
1237 a.data.as_mut_ptr().cast(),
1238 to_i32(a.leading_dimension, "lda")?,
1239 s.as_mut_ptr().cast(),
1240 u_ptr.cast(),
1241 ldu,
1242 vt_ptr.cast(),
1243 ldvt,
1244 workspace.as_mut_ptr().cast(),
1245 to_i32(lwork, "lwork")?,
1246 rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1247 dev_info.as_mut_ptr().cast(),
1248 ))?;
1249 }
1250 Ok(())
1251}
1252
1253pub fn dgesvd(
1288 ctx: &Context,
1289 job_u: SvdMode,
1290 job_vt: SvdMode,
1291 m: usize,
1292 n: usize,
1293 a: MatrixMut<'_, f64>,
1294 s: &mut DeviceMemory<f64>,
1295 u: Option<MatrixMut<'_, f64>>,
1296 vt: Option<MatrixMut<'_, f64>>,
1297 workspace: &mut DeviceMemory<f64>,
1298 rwork: Option<&mut DeviceMemory<f64>>,
1299 dev_info: &mut DeviceMemory<i32>,
1300) -> Result<()> {
1301 ctx.bind()?;
1302 validate_gesvd_inputs(
1303 m,
1304 n,
1305 a.data.len(),
1306 a.leading_dimension,
1307 s.len(),
1308 job_u,
1309 matrix_mut_ref_parts(u.as_ref()).as_ref(),
1310 job_vt,
1311 matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1312 )?;
1313 require_info_buffer(dev_info)?;
1314 require_rwork_buffer(rwork.as_deref(), m, n)?;
1315 let lwork = dgesvd_buffer_size(ctx, m, n)?;
1316 require_workspace(workspace.len(), lwork)?;
1317 let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1318 let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1319 unsafe {
1320 try_ffi!(sys::cusolverDnDgesvd(
1321 ctx.as_raw(),
1322 job_u.as_raw(),
1323 job_vt.as_raw(),
1324 to_i32(m, "m")?,
1325 to_i32(n, "n")?,
1326 a.data.as_mut_ptr().cast(),
1327 to_i32(a.leading_dimension, "lda")?,
1328 s.as_mut_ptr().cast(),
1329 u_ptr.cast(),
1330 ldu,
1331 vt_ptr.cast(),
1332 ldvt,
1333 workspace.as_mut_ptr().cast(),
1334 to_i32(lwork, "lwork")?,
1335 rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1336 dev_info.as_mut_ptr().cast(),
1337 ))?;
1338 }
1339 Ok(())
1340}
1341
1342pub fn cgesvd(
1377 ctx: &Context,
1378 job_u: SvdMode,
1379 job_vt: SvdMode,
1380 m: usize,
1381 n: usize,
1382 a: MatrixMut<'_, Complex32>,
1383 s: &mut DeviceMemory<f32>,
1384 u: Option<MatrixMut<'_, Complex32>>,
1385 vt: Option<MatrixMut<'_, Complex32>>,
1386 workspace: &mut DeviceMemory<Complex32>,
1387 rwork: Option<&mut DeviceMemory<f32>>,
1388 dev_info: &mut DeviceMemory<i32>,
1389) -> Result<()> {
1390 ctx.bind()?;
1391 validate_gesvd_inputs(
1392 m,
1393 n,
1394 a.data.len(),
1395 a.leading_dimension,
1396 s.len(),
1397 job_u,
1398 matrix_mut_ref_parts(u.as_ref()).as_ref(),
1399 job_vt,
1400 matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1401 )?;
1402 require_info_buffer(dev_info)?;
1403 require_rwork_buffer(rwork.as_deref(), m, n)?;
1404 let lwork = cgesvd_buffer_size(ctx, m, n)?;
1405 require_workspace(workspace.len(), lwork)?;
1406 let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1407 let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1408 unsafe {
1409 try_ffi!(sys::cusolverDnCgesvd(
1410 ctx.as_raw(),
1411 job_u.as_raw(),
1412 job_vt.as_raw(),
1413 to_i32(m, "m")?,
1414 to_i32(n, "n")?,
1415 a.data.as_mut_ptr().cast(),
1416 to_i32(a.leading_dimension, "lda")?,
1417 s.as_mut_ptr().cast(),
1418 u_ptr.cast(),
1419 ldu,
1420 vt_ptr.cast(),
1421 ldvt,
1422 workspace.as_mut_ptr().cast(),
1423 to_i32(lwork, "lwork")?,
1424 rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1425 dev_info.as_mut_ptr().cast(),
1426 ))?;
1427 }
1428 Ok(())
1429}
1430
1431pub fn zgesvd(
1466 ctx: &Context,
1467 job_u: SvdMode,
1468 job_vt: SvdMode,
1469 m: usize,
1470 n: usize,
1471 a: MatrixMut<'_, Complex64>,
1472 s: &mut DeviceMemory<f64>,
1473 u: Option<MatrixMut<'_, Complex64>>,
1474 vt: Option<MatrixMut<'_, Complex64>>,
1475 workspace: &mut DeviceMemory<Complex64>,
1476 rwork: Option<&mut DeviceMemory<f64>>,
1477 dev_info: &mut DeviceMemory<i32>,
1478) -> Result<()> {
1479 ctx.bind()?;
1480 validate_gesvd_inputs(
1481 m,
1482 n,
1483 a.data.len(),
1484 a.leading_dimension,
1485 s.len(),
1486 job_u,
1487 matrix_mut_ref_parts(u.as_ref()).as_ref(),
1488 job_vt,
1489 matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1490 )?;
1491 require_info_buffer(dev_info)?;
1492 require_rwork_buffer(rwork.as_deref(), m, n)?;
1493 let lwork = zgesvd_buffer_size(ctx, m, n)?;
1494 require_workspace(workspace.len(), lwork)?;
1495 let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1496 let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1497 unsafe {
1498 try_ffi!(sys::cusolverDnZgesvd(
1499 ctx.as_raw(),
1500 job_u.as_raw(),
1501 job_vt.as_raw(),
1502 to_i32(m, "m")?,
1503 to_i32(n, "n")?,
1504 a.data.as_mut_ptr().cast(),
1505 to_i32(a.leading_dimension, "lda")?,
1506 s.as_mut_ptr().cast(),
1507 u_ptr.cast(),
1508 ldu,
1509 vt_ptr.cast(),
1510 ldvt,
1511 workspace.as_mut_ptr().cast(),
1512 to_i32(lwork, "lwork")?,
1513 rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1514 dev_info.as_mut_ptr().cast(),
1515 ))?;
1516 }
1517 Ok(())
1518}
1519
1520pub fn sgesvdj_buffer_size(
1521 ctx: &Context,
1522 jobz: EigenMode,
1523 econ: bool,
1524 m: usize,
1525 n: usize,
1526 a: MatrixRef<'_, f32>,
1527 s: &DeviceMemory<f32>,
1528 u: Option<MatrixRef<'_, f32>>,
1529 v: Option<MatrixRef<'_, f32>>,
1530 params: &GesvdjInfo,
1531) -> Result<usize> {
1532 ctx.bind()?;
1533 validate_gesvdj_inputs(
1534 m,
1535 n,
1536 a.data.len(),
1537 a.leading_dimension,
1538 s.len(),
1539 jobz,
1540 econ,
1541 matrix_ref_parts(u),
1542 matrix_ref_parts(v),
1543 )?;
1544 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1545 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1546 let mut lwork = 0;
1547 unsafe {
1548 try_ffi!(sys::cusolverDnSgesvdj_bufferSize(
1549 ctx.as_raw(),
1550 jobz.into(),
1551 i32::from(econ),
1552 to_i32(m, "m")?,
1553 to_i32(n, "n")?,
1554 a.data.as_ptr().cast(),
1555 to_i32(a.leading_dimension, "lda")?,
1556 s.as_ptr().cast(),
1557 u_ptr.cast(),
1558 ldu,
1559 v_ptr.cast(),
1560 ldv,
1561 &raw mut lwork,
1562 params.as_raw(),
1563 ))?;
1564 }
1565 to_usize(lwork, "lwork")
1566}
1567
1568pub fn dgesvdj_buffer_size(
1569 ctx: &Context,
1570 jobz: EigenMode,
1571 econ: bool,
1572 m: usize,
1573 n: usize,
1574 a: MatrixRef<'_, f64>,
1575 s: &DeviceMemory<f64>,
1576 u: Option<MatrixRef<'_, f64>>,
1577 v: Option<MatrixRef<'_, f64>>,
1578 params: &GesvdjInfo,
1579) -> Result<usize> {
1580 ctx.bind()?;
1581 validate_gesvdj_inputs(
1582 m,
1583 n,
1584 a.data.len(),
1585 a.leading_dimension,
1586 s.len(),
1587 jobz,
1588 econ,
1589 matrix_ref_parts(u),
1590 matrix_ref_parts(v),
1591 )?;
1592 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1593 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1594 let mut lwork = 0;
1595 unsafe {
1596 try_ffi!(sys::cusolverDnDgesvdj_bufferSize(
1597 ctx.as_raw(),
1598 jobz.into(),
1599 i32::from(econ),
1600 to_i32(m, "m")?,
1601 to_i32(n, "n")?,
1602 a.data.as_ptr().cast(),
1603 to_i32(a.leading_dimension, "lda")?,
1604 s.as_ptr().cast(),
1605 u_ptr.cast(),
1606 ldu,
1607 v_ptr.cast(),
1608 ldv,
1609 &raw mut lwork,
1610 params.as_raw(),
1611 ))?;
1612 }
1613 to_usize(lwork, "lwork")
1614}
1615
1616pub fn cgesvdj_buffer_size(
1617 ctx: &Context,
1618 jobz: EigenMode,
1619 econ: bool,
1620 m: usize,
1621 n: usize,
1622 a: MatrixRef<'_, Complex32>,
1623 s: &DeviceMemory<f32>,
1624 u: Option<MatrixRef<'_, Complex32>>,
1625 v: Option<MatrixRef<'_, Complex32>>,
1626 params: &GesvdjInfo,
1627) -> Result<usize> {
1628 ctx.bind()?;
1629 validate_gesvdj_inputs(
1630 m,
1631 n,
1632 a.data.len(),
1633 a.leading_dimension,
1634 s.len(),
1635 jobz,
1636 econ,
1637 matrix_ref_parts(u),
1638 matrix_ref_parts(v),
1639 )?;
1640 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1641 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1642 let mut lwork = 0;
1643 unsafe {
1644 try_ffi!(sys::cusolverDnCgesvdj_bufferSize(
1645 ctx.as_raw(),
1646 jobz.into(),
1647 i32::from(econ),
1648 to_i32(m, "m")?,
1649 to_i32(n, "n")?,
1650 a.data.as_ptr().cast(),
1651 to_i32(a.leading_dimension, "lda")?,
1652 s.as_ptr().cast(),
1653 u_ptr.cast(),
1654 ldu,
1655 v_ptr.cast(),
1656 ldv,
1657 &raw mut lwork,
1658 params.as_raw(),
1659 ))?;
1660 }
1661 to_usize(lwork, "lwork")
1662}
1663
1664pub fn zgesvdj_buffer_size(
1665 ctx: &Context,
1666 jobz: EigenMode,
1667 econ: bool,
1668 m: usize,
1669 n: usize,
1670 a: MatrixRef<'_, Complex64>,
1671 s: &DeviceMemory<f64>,
1672 u: Option<MatrixRef<'_, Complex64>>,
1673 v: Option<MatrixRef<'_, Complex64>>,
1674 params: &GesvdjInfo,
1675) -> Result<usize> {
1676 ctx.bind()?;
1677 validate_gesvdj_inputs(
1678 m,
1679 n,
1680 a.data.len(),
1681 a.leading_dimension,
1682 s.len(),
1683 jobz,
1684 econ,
1685 matrix_ref_parts(u),
1686 matrix_ref_parts(v),
1687 )?;
1688 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1689 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1690 let mut lwork = 0;
1691 unsafe {
1692 try_ffi!(sys::cusolverDnZgesvdj_bufferSize(
1693 ctx.as_raw(),
1694 jobz.into(),
1695 i32::from(econ),
1696 to_i32(m, "m")?,
1697 to_i32(n, "n")?,
1698 a.data.as_ptr().cast(),
1699 to_i32(a.leading_dimension, "lda")?,
1700 s.as_ptr().cast(),
1701 u_ptr.cast(),
1702 ldu,
1703 v_ptr.cast(),
1704 ldv,
1705 &raw mut lwork,
1706 params.as_raw(),
1707 ))?;
1708 }
1709 to_usize(lwork, "lwork")
1710}
1711
1712pub fn sgesvdj(
1769 ctx: &Context,
1770 jobz: EigenMode,
1771 econ: bool,
1772 m: usize,
1773 n: usize,
1774 a: MatrixMut<'_, f32>,
1775 s: &mut DeviceMemory<f32>,
1776 u: Option<MatrixMut<'_, f32>>,
1777 v: Option<MatrixMut<'_, f32>>,
1778 workspace: &mut DeviceMemory<f32>,
1779 dev_info: &mut DeviceMemory<i32>,
1780 params: &GesvdjInfo,
1781) -> Result<()> {
1782 ctx.bind()?;
1783 validate_gesvdj_inputs(
1784 m,
1785 n,
1786 a.data.len(),
1787 a.leading_dimension,
1788 s.len(),
1789 jobz,
1790 econ,
1791 matrix_mut_ref_parts(u.as_ref()),
1792 matrix_mut_ref_parts(v.as_ref()),
1793 )?;
1794 require_info_buffer(dev_info)?;
1795 let lwork = sgesvdj_buffer_size(
1796 ctx,
1797 jobz,
1798 econ,
1799 m,
1800 n,
1801 a.as_ref(),
1802 s,
1803 matrix_mut_ref_option(u.as_ref()),
1804 matrix_mut_ref_option(v.as_ref()),
1805 params,
1806 )?;
1807 require_workspace(workspace.len(), lwork)?;
1808 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
1809 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
1810 unsafe {
1811 try_ffi!(sys::cusolverDnSgesvdj(
1812 ctx.as_raw(),
1813 jobz.into(),
1814 i32::from(econ),
1815 to_i32(m, "m")?,
1816 to_i32(n, "n")?,
1817 a.data.as_mut_ptr().cast(),
1818 to_i32(a.leading_dimension, "lda")?,
1819 s.as_mut_ptr().cast(),
1820 u_ptr.cast(),
1821 ldu,
1822 v_ptr.cast(),
1823 ldv,
1824 workspace.as_mut_ptr().cast(),
1825 to_i32(lwork, "lwork")?,
1826 dev_info.as_mut_ptr().cast(),
1827 params.as_raw(),
1828 ))?;
1829 }
1830 Ok(())
1831}
1832
1833pub fn dgesvdj(
1890 ctx: &Context,
1891 jobz: EigenMode,
1892 econ: bool,
1893 m: usize,
1894 n: usize,
1895 a: MatrixMut<'_, f64>,
1896 s: &mut DeviceMemory<f64>,
1897 u: Option<MatrixMut<'_, f64>>,
1898 v: Option<MatrixMut<'_, f64>>,
1899 workspace: &mut DeviceMemory<f64>,
1900 dev_info: &mut DeviceMemory<i32>,
1901 params: &GesvdjInfo,
1902) -> Result<()> {
1903 ctx.bind()?;
1904 validate_gesvdj_inputs(
1905 m,
1906 n,
1907 a.data.len(),
1908 a.leading_dimension,
1909 s.len(),
1910 jobz,
1911 econ,
1912 matrix_mut_ref_parts(u.as_ref()),
1913 matrix_mut_ref_parts(v.as_ref()),
1914 )?;
1915 require_info_buffer(dev_info)?;
1916 let lwork = dgesvdj_buffer_size(
1917 ctx,
1918 jobz,
1919 econ,
1920 m,
1921 n,
1922 a.as_ref(),
1923 s,
1924 matrix_mut_ref_option(u.as_ref()),
1925 matrix_mut_ref_option(v.as_ref()),
1926 params,
1927 )?;
1928 require_workspace(workspace.len(), lwork)?;
1929 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
1930 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
1931 unsafe {
1932 try_ffi!(sys::cusolverDnDgesvdj(
1933 ctx.as_raw(),
1934 jobz.into(),
1935 i32::from(econ),
1936 to_i32(m, "m")?,
1937 to_i32(n, "n")?,
1938 a.data.as_mut_ptr().cast(),
1939 to_i32(a.leading_dimension, "lda")?,
1940 s.as_mut_ptr().cast(),
1941 u_ptr.cast(),
1942 ldu,
1943 v_ptr.cast(),
1944 ldv,
1945 workspace.as_mut_ptr().cast(),
1946 to_i32(lwork, "lwork")?,
1947 dev_info.as_mut_ptr().cast(),
1948 params.as_raw(),
1949 ))?;
1950 }
1951 Ok(())
1952}
1953
1954pub fn cgesvdj(
2011 ctx: &Context,
2012 jobz: EigenMode,
2013 econ: bool,
2014 m: usize,
2015 n: usize,
2016 a: MatrixMut<'_, Complex32>,
2017 s: &mut DeviceMemory<f32>,
2018 u: Option<MatrixMut<'_, Complex32>>,
2019 v: Option<MatrixMut<'_, Complex32>>,
2020 workspace: &mut DeviceMemory<Complex32>,
2021 dev_info: &mut DeviceMemory<i32>,
2022 params: &GesvdjInfo,
2023) -> Result<()> {
2024 ctx.bind()?;
2025 validate_gesvdj_inputs(
2026 m,
2027 n,
2028 a.data.len(),
2029 a.leading_dimension,
2030 s.len(),
2031 jobz,
2032 econ,
2033 matrix_mut_ref_parts(u.as_ref()),
2034 matrix_mut_ref_parts(v.as_ref()),
2035 )?;
2036 require_info_buffer(dev_info)?;
2037 let lwork = cgesvdj_buffer_size(
2038 ctx,
2039 jobz,
2040 econ,
2041 m,
2042 n,
2043 a.as_ref(),
2044 s,
2045 matrix_mut_ref_option(u.as_ref()),
2046 matrix_mut_ref_option(v.as_ref()),
2047 params,
2048 )?;
2049 require_workspace(workspace.len(), lwork)?;
2050 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
2051 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
2052 unsafe {
2053 try_ffi!(sys::cusolverDnCgesvdj(
2054 ctx.as_raw(),
2055 jobz.into(),
2056 i32::from(econ),
2057 to_i32(m, "m")?,
2058 to_i32(n, "n")?,
2059 a.data.as_mut_ptr().cast(),
2060 to_i32(a.leading_dimension, "lda")?,
2061 s.as_mut_ptr().cast(),
2062 u_ptr.cast(),
2063 ldu,
2064 v_ptr.cast(),
2065 ldv,
2066 workspace.as_mut_ptr().cast(),
2067 to_i32(lwork, "lwork")?,
2068 dev_info.as_mut_ptr().cast(),
2069 params.as_raw(),
2070 ))?;
2071 }
2072 Ok(())
2073}
2074
2075pub fn zgesvdj(
2132 ctx: &Context,
2133 jobz: EigenMode,
2134 econ: bool,
2135 m: usize,
2136 n: usize,
2137 a: MatrixMut<'_, Complex64>,
2138 s: &mut DeviceMemory<f64>,
2139 u: Option<MatrixMut<'_, Complex64>>,
2140 v: Option<MatrixMut<'_, Complex64>>,
2141 workspace: &mut DeviceMemory<Complex64>,
2142 dev_info: &mut DeviceMemory<i32>,
2143 params: &GesvdjInfo,
2144) -> Result<()> {
2145 ctx.bind()?;
2146 validate_gesvdj_inputs(
2147 m,
2148 n,
2149 a.data.len(),
2150 a.leading_dimension,
2151 s.len(),
2152 jobz,
2153 econ,
2154 matrix_mut_ref_parts(u.as_ref()),
2155 matrix_mut_ref_parts(v.as_ref()),
2156 )?;
2157 require_info_buffer(dev_info)?;
2158 let lwork = zgesvdj_buffer_size(
2159 ctx,
2160 jobz,
2161 econ,
2162 m,
2163 n,
2164 a.as_ref(),
2165 s,
2166 matrix_mut_ref_option(u.as_ref()),
2167 matrix_mut_ref_option(v.as_ref()),
2168 params,
2169 )?;
2170 require_workspace(workspace.len(), lwork)?;
2171 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
2172 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
2173 unsafe {
2174 try_ffi!(sys::cusolverDnZgesvdj(
2175 ctx.as_raw(),
2176 jobz.into(),
2177 i32::from(econ),
2178 to_i32(m, "m")?,
2179 to_i32(n, "n")?,
2180 a.data.as_mut_ptr().cast(),
2181 to_i32(a.leading_dimension, "lda")?,
2182 s.as_mut_ptr().cast(),
2183 u_ptr.cast(),
2184 ldu,
2185 v_ptr.cast(),
2186 ldv,
2187 workspace.as_mut_ptr().cast(),
2188 to_i32(lwork, "lwork")?,
2189 dev_info.as_mut_ptr().cast(),
2190 params.as_raw(),
2191 ))?;
2192 }
2193 Ok(())
2194}
2195
2196pub fn sgesvdj_batched_buffer_size(
2197 ctx: &Context,
2198 jobz: EigenMode,
2199 m: usize,
2200 n: usize,
2201 a: MatrixRef<'_, f32>,
2202 s: &DeviceMemory<f32>,
2203 u: Option<MatrixRef<'_, f32>>,
2204 v: Option<MatrixRef<'_, f32>>,
2205 params: &GesvdjInfo,
2206 batch_size: usize,
2207) -> Result<usize> {
2208 ctx.bind()?;
2209 validate_gesvdj_batched_inputs(
2210 m,
2211 n,
2212 a.data.len(),
2213 a.leading_dimension,
2214 s.len(),
2215 jobz,
2216 matrix_ref_parts(u),
2217 matrix_ref_parts(v),
2218 batch_size,
2219 )?;
2220 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2221 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2222 let mut lwork = 0;
2223 unsafe {
2224 try_ffi!(sys::cusolverDnSgesvdjBatched_bufferSize(
2225 ctx.as_raw(),
2226 jobz.into(),
2227 to_i32(m, "m")?,
2228 to_i32(n, "n")?,
2229 a.data.as_ptr().cast(),
2230 to_i32(a.leading_dimension, "lda")?,
2231 s.as_ptr().cast(),
2232 u_ptr.cast(),
2233 ldu,
2234 v_ptr.cast(),
2235 ldv,
2236 &raw mut lwork,
2237 params.as_raw(),
2238 to_i32(batch_size, "batch_size")?,
2239 ))?;
2240 }
2241 to_usize(lwork, "lwork")
2242}
2243
2244pub fn dgesvdj_batched_buffer_size(
2245 ctx: &Context,
2246 jobz: EigenMode,
2247 m: usize,
2248 n: usize,
2249 a: MatrixRef<'_, f64>,
2250 s: &DeviceMemory<f64>,
2251 u: Option<MatrixRef<'_, f64>>,
2252 v: Option<MatrixRef<'_, f64>>,
2253 params: &GesvdjInfo,
2254 batch_size: usize,
2255) -> Result<usize> {
2256 ctx.bind()?;
2257 validate_gesvdj_batched_inputs(
2258 m,
2259 n,
2260 a.data.len(),
2261 a.leading_dimension,
2262 s.len(),
2263 jobz,
2264 matrix_ref_parts(u),
2265 matrix_ref_parts(v),
2266 batch_size,
2267 )?;
2268 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2269 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2270 let mut lwork = 0;
2271 unsafe {
2272 try_ffi!(sys::cusolverDnDgesvdjBatched_bufferSize(
2273 ctx.as_raw(),
2274 jobz.into(),
2275 to_i32(m, "m")?,
2276 to_i32(n, "n")?,
2277 a.data.as_ptr().cast(),
2278 to_i32(a.leading_dimension, "lda")?,
2279 s.as_ptr().cast(),
2280 u_ptr.cast(),
2281 ldu,
2282 v_ptr.cast(),
2283 ldv,
2284 &raw mut lwork,
2285 params.as_raw(),
2286 to_i32(batch_size, "batch_size")?,
2287 ))?;
2288 }
2289 to_usize(lwork, "lwork")
2290}
2291
2292pub fn cgesvdj_batched_buffer_size(
2293 ctx: &Context,
2294 jobz: EigenMode,
2295 m: usize,
2296 n: usize,
2297 a: MatrixRef<'_, Complex32>,
2298 s: &DeviceMemory<f32>,
2299 u: Option<MatrixRef<'_, Complex32>>,
2300 v: Option<MatrixRef<'_, Complex32>>,
2301 params: &GesvdjInfo,
2302 batch_size: usize,
2303) -> Result<usize> {
2304 ctx.bind()?;
2305 validate_gesvdj_batched_inputs(
2306 m,
2307 n,
2308 a.data.len(),
2309 a.leading_dimension,
2310 s.len(),
2311 jobz,
2312 matrix_ref_parts(u),
2313 matrix_ref_parts(v),
2314 batch_size,
2315 )?;
2316 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2317 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2318 let mut lwork = 0;
2319 unsafe {
2320 try_ffi!(sys::cusolverDnCgesvdjBatched_bufferSize(
2321 ctx.as_raw(),
2322 jobz.into(),
2323 to_i32(m, "m")?,
2324 to_i32(n, "n")?,
2325 a.data.as_ptr().cast(),
2326 to_i32(a.leading_dimension, "lda")?,
2327 s.as_ptr().cast(),
2328 u_ptr.cast(),
2329 ldu,
2330 v_ptr.cast(),
2331 ldv,
2332 &raw mut lwork,
2333 params.as_raw(),
2334 to_i32(batch_size, "batch_size")?,
2335 ))?;
2336 }
2337 to_usize(lwork, "lwork")
2338}
2339
2340pub fn zgesvdj_batched_buffer_size(
2341 ctx: &Context,
2342 jobz: EigenMode,
2343 m: usize,
2344 n: usize,
2345 a: MatrixRef<'_, Complex64>,
2346 s: &DeviceMemory<f64>,
2347 u: Option<MatrixRef<'_, Complex64>>,
2348 v: Option<MatrixRef<'_, Complex64>>,
2349 params: &GesvdjInfo,
2350 batch_size: usize,
2351) -> Result<usize> {
2352 ctx.bind()?;
2353 validate_gesvdj_batched_inputs(
2354 m,
2355 n,
2356 a.data.len(),
2357 a.leading_dimension,
2358 s.len(),
2359 jobz,
2360 matrix_ref_parts(u),
2361 matrix_ref_parts(v),
2362 batch_size,
2363 )?;
2364 let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2365 let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2366 let mut lwork = 0;
2367 unsafe {
2368 try_ffi!(sys::cusolverDnZgesvdjBatched_bufferSize(
2369 ctx.as_raw(),
2370 jobz.into(),
2371 to_i32(m, "m")?,
2372 to_i32(n, "n")?,
2373 a.data.as_ptr().cast(),
2374 to_i32(a.leading_dimension, "lda")?,
2375 s.as_ptr().cast(),
2376 u_ptr.cast(),
2377 ldu,
2378 v_ptr.cast(),
2379 ldv,
2380 &raw mut lwork,
2381 params.as_raw(),
2382 to_i32(batch_size, "batch_size")?,
2383 ))?;
2384 }
2385 to_usize(lwork, "lwork")
2386}
2387
2388pub fn sgesvdj_batched(
2429 ctx: &Context,
2430 jobz: EigenMode,
2431 m: usize,
2432 n: usize,
2433 a: MatrixMut<'_, f32>,
2434 s: &mut DeviceMemory<f32>,
2435 u: Option<MatrixMut<'_, f32>>,
2436 v: Option<MatrixMut<'_, f32>>,
2437 workspace: &mut DeviceMemory<f32>,
2438 dev_info: &mut DeviceMemory<i32>,
2439 params: &GesvdjInfo,
2440 batch_size: usize,
2441) -> Result<()> {
2442 ctx.bind()?;
2443 validate_gesvdj_batched_inputs(
2444 m,
2445 n,
2446 a.data.len(),
2447 a.leading_dimension,
2448 s.len(),
2449 jobz,
2450 matrix_mut_ref_parts(u.as_ref()),
2451 matrix_mut_ref_parts(v.as_ref()),
2452 batch_size,
2453 )?;
2454 require_info_buffer_len(dev_info, batch_size)?;
2455 let lwork = sgesvdj_batched_buffer_size(
2456 ctx,
2457 jobz,
2458 m,
2459 n,
2460 a.as_ref(),
2461 s,
2462 matrix_mut_ref_option(u.as_ref()),
2463 matrix_mut_ref_option(v.as_ref()),
2464 params,
2465 batch_size,
2466 )?;
2467 require_workspace(workspace.len(), lwork)?;
2468 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2469 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2470 unsafe {
2471 try_ffi!(sys::cusolverDnSgesvdjBatched(
2472 ctx.as_raw(),
2473 jobz.into(),
2474 to_i32(m, "m")?,
2475 to_i32(n, "n")?,
2476 a.data.as_mut_ptr().cast(),
2477 to_i32(a.leading_dimension, "lda")?,
2478 s.as_mut_ptr().cast(),
2479 u_ptr.cast(),
2480 ldu,
2481 v_ptr.cast(),
2482 ldv,
2483 workspace.as_mut_ptr().cast(),
2484 to_i32(lwork, "lwork")?,
2485 dev_info.as_mut_ptr().cast(),
2486 params.as_raw(),
2487 to_i32(batch_size, "batch_size")?,
2488 ))?;
2489 }
2490 Ok(())
2491}
2492
2493pub fn dgesvdj_batched(
2534 ctx: &Context,
2535 jobz: EigenMode,
2536 m: usize,
2537 n: usize,
2538 a: MatrixMut<'_, f64>,
2539 s: &mut DeviceMemory<f64>,
2540 u: Option<MatrixMut<'_, f64>>,
2541 v: Option<MatrixMut<'_, f64>>,
2542 workspace: &mut DeviceMemory<f64>,
2543 dev_info: &mut DeviceMemory<i32>,
2544 params: &GesvdjInfo,
2545 batch_size: usize,
2546) -> Result<()> {
2547 ctx.bind()?;
2548 validate_gesvdj_batched_inputs(
2549 m,
2550 n,
2551 a.data.len(),
2552 a.leading_dimension,
2553 s.len(),
2554 jobz,
2555 matrix_mut_ref_parts(u.as_ref()),
2556 matrix_mut_ref_parts(v.as_ref()),
2557 batch_size,
2558 )?;
2559 require_info_buffer_len(dev_info, batch_size)?;
2560 let lwork = dgesvdj_batched_buffer_size(
2561 ctx,
2562 jobz,
2563 m,
2564 n,
2565 a.as_ref(),
2566 s,
2567 matrix_mut_ref_option(u.as_ref()),
2568 matrix_mut_ref_option(v.as_ref()),
2569 params,
2570 batch_size,
2571 )?;
2572 require_workspace(workspace.len(), lwork)?;
2573 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2574 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2575 unsafe {
2576 try_ffi!(sys::cusolverDnDgesvdjBatched(
2577 ctx.as_raw(),
2578 jobz.into(),
2579 to_i32(m, "m")?,
2580 to_i32(n, "n")?,
2581 a.data.as_mut_ptr().cast(),
2582 to_i32(a.leading_dimension, "lda")?,
2583 s.as_mut_ptr().cast(),
2584 u_ptr.cast(),
2585 ldu,
2586 v_ptr.cast(),
2587 ldv,
2588 workspace.as_mut_ptr().cast(),
2589 to_i32(lwork, "lwork")?,
2590 dev_info.as_mut_ptr().cast(),
2591 params.as_raw(),
2592 to_i32(batch_size, "batch_size")?,
2593 ))?;
2594 }
2595 Ok(())
2596}
2597
2598pub fn cgesvdj_batched(
2639 ctx: &Context,
2640 jobz: EigenMode,
2641 m: usize,
2642 n: usize,
2643 a: MatrixMut<'_, Complex32>,
2644 s: &mut DeviceMemory<f32>,
2645 u: Option<MatrixMut<'_, Complex32>>,
2646 v: Option<MatrixMut<'_, Complex32>>,
2647 workspace: &mut DeviceMemory<Complex32>,
2648 dev_info: &mut DeviceMemory<i32>,
2649 params: &GesvdjInfo,
2650 batch_size: usize,
2651) -> Result<()> {
2652 ctx.bind()?;
2653 validate_gesvdj_batched_inputs(
2654 m,
2655 n,
2656 a.data.len(),
2657 a.leading_dimension,
2658 s.len(),
2659 jobz,
2660 matrix_mut_ref_parts(u.as_ref()),
2661 matrix_mut_ref_parts(v.as_ref()),
2662 batch_size,
2663 )?;
2664 require_info_buffer_len(dev_info, batch_size)?;
2665 let lwork = cgesvdj_batched_buffer_size(
2666 ctx,
2667 jobz,
2668 m,
2669 n,
2670 a.as_ref(),
2671 s,
2672 matrix_mut_ref_option(u.as_ref()),
2673 matrix_mut_ref_option(v.as_ref()),
2674 params,
2675 batch_size,
2676 )?;
2677 require_workspace(workspace.len(), lwork)?;
2678 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2679 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2680 unsafe {
2681 try_ffi!(sys::cusolverDnCgesvdjBatched(
2682 ctx.as_raw(),
2683 jobz.into(),
2684 to_i32(m, "m")?,
2685 to_i32(n, "n")?,
2686 a.data.as_mut_ptr().cast(),
2687 to_i32(a.leading_dimension, "lda")?,
2688 s.as_mut_ptr().cast(),
2689 u_ptr.cast(),
2690 ldu,
2691 v_ptr.cast(),
2692 ldv,
2693 workspace.as_mut_ptr().cast(),
2694 to_i32(lwork, "lwork")?,
2695 dev_info.as_mut_ptr().cast(),
2696 params.as_raw(),
2697 to_i32(batch_size, "batch_size")?,
2698 ))?;
2699 }
2700 Ok(())
2701}
2702
2703pub fn zgesvdj_batched(
2744 ctx: &Context,
2745 jobz: EigenMode,
2746 m: usize,
2747 n: usize,
2748 a: MatrixMut<'_, Complex64>,
2749 s: &mut DeviceMemory<f64>,
2750 u: Option<MatrixMut<'_, Complex64>>,
2751 v: Option<MatrixMut<'_, Complex64>>,
2752 workspace: &mut DeviceMemory<Complex64>,
2753 dev_info: &mut DeviceMemory<i32>,
2754 params: &GesvdjInfo,
2755 batch_size: usize,
2756) -> Result<()> {
2757 ctx.bind()?;
2758 validate_gesvdj_batched_inputs(
2759 m,
2760 n,
2761 a.data.len(),
2762 a.leading_dimension,
2763 s.len(),
2764 jobz,
2765 matrix_mut_ref_parts(u.as_ref()),
2766 matrix_mut_ref_parts(v.as_ref()),
2767 batch_size,
2768 )?;
2769 require_info_buffer_len(dev_info, batch_size)?;
2770 let lwork = zgesvdj_batched_buffer_size(
2771 ctx,
2772 jobz,
2773 m,
2774 n,
2775 a.as_ref(),
2776 s,
2777 matrix_mut_ref_option(u.as_ref()),
2778 matrix_mut_ref_option(v.as_ref()),
2779 params,
2780 batch_size,
2781 )?;
2782 require_workspace(workspace.len(), lwork)?;
2783 let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2784 let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2785 unsafe {
2786 try_ffi!(sys::cusolverDnZgesvdjBatched(
2787 ctx.as_raw(),
2788 jobz.into(),
2789 to_i32(m, "m")?,
2790 to_i32(n, "n")?,
2791 a.data.as_mut_ptr().cast(),
2792 to_i32(a.leading_dimension, "lda")?,
2793 s.as_mut_ptr().cast(),
2794 u_ptr.cast(),
2795 ldu,
2796 v_ptr.cast(),
2797 ldv,
2798 workspace.as_mut_ptr().cast(),
2799 to_i32(lwork, "lwork")?,
2800 dev_info.as_mut_ptr().cast(),
2801 params.as_raw(),
2802 to_i32(batch_size, "batch_size")?,
2803 ))?;
2804 }
2805 Ok(())
2806}
2807
2808pub fn sgesvda_strided_batched_buffer_size(
2809 ctx: &Context,
2810 jobz: EigenMode,
2811 rank: usize,
2812 m: usize,
2813 n: usize,
2814 a: StridedBatchedMatrixRef<'_, f32>,
2815 s: StridedBatchedVectorRef<'_, f32>,
2816 u: Option<StridedBatchedMatrixRef<'_, f32>>,
2817 v: Option<StridedBatchedMatrixRef<'_, f32>>,
2818 batch_size: usize,
2819) -> Result<usize> {
2820 ctx.bind()?;
2821 validate_gesvda_strided_batched_inputs(
2822 rank,
2823 m,
2824 n,
2825 a.data.len(),
2826 a.leading_dimension,
2827 a.stride,
2828 s.data.len(),
2829 s.stride,
2830 jobz,
2831 strided_batched_matrix_ref_parts(u),
2832 strided_batched_matrix_ref_parts(v),
2833 batch_size,
2834 )?;
2835 let (u_ptr, ldu, stride_u) =
2836 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2837 let (v_ptr, ldv, stride_v) =
2838 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2839 let mut lwork = 0;
2840 unsafe {
2841 try_ffi!(sys::cusolverDnSgesvdaStridedBatched_bufferSize(
2842 ctx.as_raw(),
2843 jobz.into(),
2844 to_i32(rank, "rank")?,
2845 to_i32(m, "m")?,
2846 to_i32(n, "n")?,
2847 a.data.as_ptr().cast(),
2848 to_i32(a.leading_dimension, "lda")?,
2849 to_i64(a.stride, "stride_a")?,
2850 s.data.as_ptr().cast(),
2851 to_i64(s.stride, "stride_s")?,
2852 u_ptr.cast(),
2853 ldu,
2854 stride_u,
2855 v_ptr.cast(),
2856 ldv,
2857 stride_v,
2858 &raw mut lwork,
2859 to_i32(batch_size, "batch_size")?,
2860 ))?;
2861 }
2862 to_usize(lwork, "lwork")
2863}
2864
2865pub fn dgesvda_strided_batched_buffer_size(
2866 ctx: &Context,
2867 jobz: EigenMode,
2868 rank: usize,
2869 m: usize,
2870 n: usize,
2871 a: StridedBatchedMatrixRef<'_, f64>,
2872 s: StridedBatchedVectorRef<'_, f64>,
2873 u: Option<StridedBatchedMatrixRef<'_, f64>>,
2874 v: Option<StridedBatchedMatrixRef<'_, f64>>,
2875 batch_size: usize,
2876) -> Result<usize> {
2877 ctx.bind()?;
2878 validate_gesvda_strided_batched_inputs(
2879 rank,
2880 m,
2881 n,
2882 a.data.len(),
2883 a.leading_dimension,
2884 a.stride,
2885 s.data.len(),
2886 s.stride,
2887 jobz,
2888 strided_batched_matrix_ref_parts(u),
2889 strided_batched_matrix_ref_parts(v),
2890 batch_size,
2891 )?;
2892 let (u_ptr, ldu, stride_u) =
2893 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2894 let (v_ptr, ldv, stride_v) =
2895 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2896 let mut lwork = 0;
2897 unsafe {
2898 try_ffi!(sys::cusolverDnDgesvdaStridedBatched_bufferSize(
2899 ctx.as_raw(),
2900 jobz.into(),
2901 to_i32(rank, "rank")?,
2902 to_i32(m, "m")?,
2903 to_i32(n, "n")?,
2904 a.data.as_ptr().cast(),
2905 to_i32(a.leading_dimension, "lda")?,
2906 to_i64(a.stride, "stride_a")?,
2907 s.data.as_ptr().cast(),
2908 to_i64(s.stride, "stride_s")?,
2909 u_ptr.cast(),
2910 ldu,
2911 stride_u,
2912 v_ptr.cast(),
2913 ldv,
2914 stride_v,
2915 &raw mut lwork,
2916 to_i32(batch_size, "batch_size")?,
2917 ))?;
2918 }
2919 to_usize(lwork, "lwork")
2920}
2921
2922pub fn cgesvda_strided_batched_buffer_size(
2923 ctx: &Context,
2924 jobz: EigenMode,
2925 rank: usize,
2926 m: usize,
2927 n: usize,
2928 a: StridedBatchedMatrixRef<'_, Complex32>,
2929 s: StridedBatchedVectorRef<'_, f32>,
2930 u: Option<StridedBatchedMatrixRef<'_, Complex32>>,
2931 v: Option<StridedBatchedMatrixRef<'_, Complex32>>,
2932 batch_size: usize,
2933) -> Result<usize> {
2934 ctx.bind()?;
2935 validate_gesvda_strided_batched_inputs(
2936 rank,
2937 m,
2938 n,
2939 a.data.len(),
2940 a.leading_dimension,
2941 a.stride,
2942 s.data.len(),
2943 s.stride,
2944 jobz,
2945 strided_batched_matrix_ref_parts(u),
2946 strided_batched_matrix_ref_parts(v),
2947 batch_size,
2948 )?;
2949 let (u_ptr, ldu, stride_u) =
2950 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2951 let (v_ptr, ldv, stride_v) =
2952 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2953 let mut lwork = 0;
2954 unsafe {
2955 try_ffi!(sys::cusolverDnCgesvdaStridedBatched_bufferSize(
2956 ctx.as_raw(),
2957 jobz.into(),
2958 to_i32(rank, "rank")?,
2959 to_i32(m, "m")?,
2960 to_i32(n, "n")?,
2961 a.data.as_ptr().cast(),
2962 to_i32(a.leading_dimension, "lda")?,
2963 to_i64(a.stride, "stride_a")?,
2964 s.data.as_ptr().cast(),
2965 to_i64(s.stride, "stride_s")?,
2966 u_ptr.cast(),
2967 ldu,
2968 stride_u,
2969 v_ptr.cast(),
2970 ldv,
2971 stride_v,
2972 &raw mut lwork,
2973 to_i32(batch_size, "batch_size")?,
2974 ))?;
2975 }
2976 to_usize(lwork, "lwork")
2977}
2978
2979pub fn zgesvda_strided_batched_buffer_size(
2980 ctx: &Context,
2981 jobz: EigenMode,
2982 rank: usize,
2983 m: usize,
2984 n: usize,
2985 a: StridedBatchedMatrixRef<'_, Complex64>,
2986 s: StridedBatchedVectorRef<'_, f64>,
2987 u: Option<StridedBatchedMatrixRef<'_, Complex64>>,
2988 v: Option<StridedBatchedMatrixRef<'_, Complex64>>,
2989 batch_size: usize,
2990) -> Result<usize> {
2991 ctx.bind()?;
2992 validate_gesvda_strided_batched_inputs(
2993 rank,
2994 m,
2995 n,
2996 a.data.len(),
2997 a.leading_dimension,
2998 a.stride,
2999 s.data.len(),
3000 s.stride,
3001 jobz,
3002 strided_batched_matrix_ref_parts(u),
3003 strided_batched_matrix_ref_parts(v),
3004 batch_size,
3005 )?;
3006 let (u_ptr, ldu, stride_u) =
3007 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
3008 let (v_ptr, ldv, stride_v) =
3009 optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
3010 let mut lwork = 0;
3011 unsafe {
3012 try_ffi!(sys::cusolverDnZgesvdaStridedBatched_bufferSize(
3013 ctx.as_raw(),
3014 jobz.into(),
3015 to_i32(rank, "rank")?,
3016 to_i32(m, "m")?,
3017 to_i32(n, "n")?,
3018 a.data.as_ptr().cast(),
3019 to_i32(a.leading_dimension, "lda")?,
3020 to_i64(a.stride, "stride_a")?,
3021 s.data.as_ptr().cast(),
3022 to_i64(s.stride, "stride_s")?,
3023 u_ptr.cast(),
3024 ldu,
3025 stride_u,
3026 v_ptr.cast(),
3027 ldv,
3028 stride_v,
3029 &raw mut lwork,
3030 to_i32(batch_size, "batch_size")?,
3031 ))?;
3032 }
3033 to_usize(lwork, "lwork")
3034}
3035
3036pub fn sgesvda_strided_batched(
3100 ctx: &Context,
3101 jobz: EigenMode,
3102 rank: usize,
3103 m: usize,
3104 n: usize,
3105 a: StridedBatchedMatrixRef<'_, f32>,
3106 s: StridedBatchedVectorMut<'_, f32>,
3107 u: Option<StridedBatchedMatrixMut<'_, f32>>,
3108 v: Option<StridedBatchedMatrixMut<'_, f32>>,
3109 workspace: &mut DeviceMemory<f32>,
3110 dev_info: &mut DeviceMemory<i32>,
3111 residual: Option<&mut f64>,
3112 batch_size: usize,
3113) -> Result<()> {
3114 ctx.bind()?;
3115 validate_gesvda_strided_batched_inputs(
3116 rank,
3117 m,
3118 n,
3119 a.data.len(),
3120 a.leading_dimension,
3121 a.stride,
3122 s.data.len(),
3123 s.stride,
3124 jobz,
3125 strided_batched_matrix_mut_ref_option(u.as_ref())
3126 .map(|m| (m.data, m.leading_dimension, m.stride)),
3127 strided_batched_matrix_mut_ref_option(v.as_ref())
3128 .map(|m| (m.data, m.leading_dimension, m.stride)),
3129 batch_size,
3130 )?;
3131 require_info_buffer_len(dev_info, batch_size)?;
3132 let lwork = sgesvda_strided_batched_buffer_size(
3133 ctx,
3134 jobz,
3135 rank,
3136 m,
3137 n,
3138 a,
3139 s.as_ref(),
3140 strided_batched_matrix_mut_ref_option(u.as_ref()),
3141 strided_batched_matrix_mut_ref_option(v.as_ref()),
3142 batch_size,
3143 )?;
3144 require_workspace(workspace.len(), lwork)?;
3145 let (u_ptr, ldu, stride_u) =
3146 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3147 let (v_ptr, ldv, stride_v) =
3148 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3149 unsafe {
3150 try_ffi!(sys::cusolverDnSgesvdaStridedBatched(
3151 ctx.as_raw(),
3152 jobz.into(),
3153 to_i32(rank, "rank")?,
3154 to_i32(m, "m")?,
3155 to_i32(n, "n")?,
3156 a.data.as_ptr().cast(),
3157 to_i32(a.leading_dimension, "lda")?,
3158 to_i64(a.stride, "stride_a")?,
3159 s.data.as_mut_ptr().cast(),
3160 to_i64(s.stride, "stride_s")?,
3161 u_ptr.cast(),
3162 ldu,
3163 stride_u,
3164 v_ptr.cast(),
3165 ldv,
3166 stride_v,
3167 workspace.as_mut_ptr().cast(),
3168 to_i32(lwork, "lwork")?,
3169 dev_info.as_mut_ptr().cast(),
3170 residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3171 to_i32(batch_size, "batch_size")?,
3172 ))?;
3173 }
3174 Ok(())
3175}
3176
3177pub fn dgesvda_strided_batched(
3241 ctx: &Context,
3242 jobz: EigenMode,
3243 rank: usize,
3244 m: usize,
3245 n: usize,
3246 a: StridedBatchedMatrixRef<'_, f64>,
3247 s: StridedBatchedVectorMut<'_, f64>,
3248 u: Option<StridedBatchedMatrixMut<'_, f64>>,
3249 v: Option<StridedBatchedMatrixMut<'_, f64>>,
3250 workspace: &mut DeviceMemory<f64>,
3251 dev_info: &mut DeviceMemory<i32>,
3252 residual: Option<&mut f64>,
3253 batch_size: usize,
3254) -> Result<()> {
3255 ctx.bind()?;
3256 validate_gesvda_strided_batched_inputs(
3257 rank,
3258 m,
3259 n,
3260 a.data.len(),
3261 a.leading_dimension,
3262 a.stride,
3263 s.data.len(),
3264 s.stride,
3265 jobz,
3266 strided_batched_matrix_mut_ref_option(u.as_ref())
3267 .map(|m| (m.data, m.leading_dimension, m.stride)),
3268 strided_batched_matrix_mut_ref_option(v.as_ref())
3269 .map(|m| (m.data, m.leading_dimension, m.stride)),
3270 batch_size,
3271 )?;
3272 require_info_buffer_len(dev_info, batch_size)?;
3273 let lwork = dgesvda_strided_batched_buffer_size(
3274 ctx,
3275 jobz,
3276 rank,
3277 m,
3278 n,
3279 a,
3280 s.as_ref(),
3281 strided_batched_matrix_mut_ref_option(u.as_ref()),
3282 strided_batched_matrix_mut_ref_option(v.as_ref()),
3283 batch_size,
3284 )?;
3285 require_workspace(workspace.len(), lwork)?;
3286 let (u_ptr, ldu, stride_u) =
3287 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3288 let (v_ptr, ldv, stride_v) =
3289 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3290 unsafe {
3291 try_ffi!(sys::cusolverDnDgesvdaStridedBatched(
3292 ctx.as_raw(),
3293 jobz.into(),
3294 to_i32(rank, "rank")?,
3295 to_i32(m, "m")?,
3296 to_i32(n, "n")?,
3297 a.data.as_ptr().cast(),
3298 to_i32(a.leading_dimension, "lda")?,
3299 to_i64(a.stride, "stride_a")?,
3300 s.data.as_mut_ptr().cast(),
3301 to_i64(s.stride, "stride_s")?,
3302 u_ptr.cast(),
3303 ldu,
3304 stride_u,
3305 v_ptr.cast(),
3306 ldv,
3307 stride_v,
3308 workspace.as_mut_ptr().cast(),
3309 to_i32(lwork, "lwork")?,
3310 dev_info.as_mut_ptr().cast(),
3311 residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3312 to_i32(batch_size, "batch_size")?,
3313 ))?;
3314 }
3315 Ok(())
3316}
3317
3318pub fn cgesvda_strided_batched(
3382 ctx: &Context,
3383 jobz: EigenMode,
3384 rank: usize,
3385 m: usize,
3386 n: usize,
3387 a: StridedBatchedMatrixRef<'_, Complex32>,
3388 s: StridedBatchedVectorMut<'_, f32>,
3389 u: Option<StridedBatchedMatrixMut<'_, Complex32>>,
3390 v: Option<StridedBatchedMatrixMut<'_, Complex32>>,
3391 workspace: &mut DeviceMemory<Complex32>,
3392 dev_info: &mut DeviceMemory<i32>,
3393 residual: Option<&mut f64>,
3394 batch_size: usize,
3395) -> Result<()> {
3396 ctx.bind()?;
3397 validate_gesvda_strided_batched_inputs(
3398 rank,
3399 m,
3400 n,
3401 a.data.len(),
3402 a.leading_dimension,
3403 a.stride,
3404 s.data.len(),
3405 s.stride,
3406 jobz,
3407 strided_batched_matrix_mut_ref_option(u.as_ref())
3408 .map(|m| (m.data, m.leading_dimension, m.stride)),
3409 strided_batched_matrix_mut_ref_option(v.as_ref())
3410 .map(|m| (m.data, m.leading_dimension, m.stride)),
3411 batch_size,
3412 )?;
3413 require_info_buffer_len(dev_info, batch_size)?;
3414 let lwork = cgesvda_strided_batched_buffer_size(
3415 ctx,
3416 jobz,
3417 rank,
3418 m,
3419 n,
3420 a,
3421 s.as_ref(),
3422 strided_batched_matrix_mut_ref_option(u.as_ref()),
3423 strided_batched_matrix_mut_ref_option(v.as_ref()),
3424 batch_size,
3425 )?;
3426 require_workspace(workspace.len(), lwork)?;
3427 let (u_ptr, ldu, stride_u) =
3428 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3429 let (v_ptr, ldv, stride_v) =
3430 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3431 unsafe {
3432 try_ffi!(sys::cusolverDnCgesvdaStridedBatched(
3433 ctx.as_raw(),
3434 jobz.into(),
3435 to_i32(rank, "rank")?,
3436 to_i32(m, "m")?,
3437 to_i32(n, "n")?,
3438 a.data.as_ptr().cast(),
3439 to_i32(a.leading_dimension, "lda")?,
3440 to_i64(a.stride, "stride_a")?,
3441 s.data.as_mut_ptr().cast(),
3442 to_i64(s.stride, "stride_s")?,
3443 u_ptr.cast(),
3444 ldu,
3445 stride_u,
3446 v_ptr.cast(),
3447 ldv,
3448 stride_v,
3449 workspace.as_mut_ptr().cast(),
3450 to_i32(lwork, "lwork")?,
3451 dev_info.as_mut_ptr().cast(),
3452 residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3453 to_i32(batch_size, "batch_size")?,
3454 ))?;
3455 }
3456 Ok(())
3457}
3458
3459pub fn zgesvda_strided_batched(
3523 ctx: &Context,
3524 jobz: EigenMode,
3525 rank: usize,
3526 m: usize,
3527 n: usize,
3528 a: StridedBatchedMatrixRef<'_, Complex64>,
3529 s: StridedBatchedVectorMut<'_, f64>,
3530 u: Option<StridedBatchedMatrixMut<'_, Complex64>>,
3531 v: Option<StridedBatchedMatrixMut<'_, Complex64>>,
3532 workspace: &mut DeviceMemory<Complex64>,
3533 dev_info: &mut DeviceMemory<i32>,
3534 residual: Option<&mut f64>,
3535 batch_size: usize,
3536) -> Result<()> {
3537 ctx.bind()?;
3538 validate_gesvda_strided_batched_inputs(
3539 rank,
3540 m,
3541 n,
3542 a.data.len(),
3543 a.leading_dimension,
3544 a.stride,
3545 s.data.len(),
3546 s.stride,
3547 jobz,
3548 strided_batched_matrix_mut_ref_option(u.as_ref())
3549 .map(|m| (m.data, m.leading_dimension, m.stride)),
3550 strided_batched_matrix_mut_ref_option(v.as_ref())
3551 .map(|m| (m.data, m.leading_dimension, m.stride)),
3552 batch_size,
3553 )?;
3554 require_info_buffer_len(dev_info, batch_size)?;
3555 let lwork = zgesvda_strided_batched_buffer_size(
3556 ctx,
3557 jobz,
3558 rank,
3559 m,
3560 n,
3561 a,
3562 s.as_ref(),
3563 strided_batched_matrix_mut_ref_option(u.as_ref()),
3564 strided_batched_matrix_mut_ref_option(v.as_ref()),
3565 batch_size,
3566 )?;
3567 require_workspace(workspace.len(), lwork)?;
3568 let (u_ptr, ldu, stride_u) =
3569 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3570 let (v_ptr, ldv, stride_v) =
3571 optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3572 unsafe {
3573 try_ffi!(sys::cusolverDnZgesvdaStridedBatched(
3574 ctx.as_raw(),
3575 jobz.into(),
3576 to_i32(rank, "rank")?,
3577 to_i32(m, "m")?,
3578 to_i32(n, "n")?,
3579 a.data.as_ptr().cast(),
3580 to_i32(a.leading_dimension, "lda")?,
3581 to_i64(a.stride, "stride_a")?,
3582 s.data.as_mut_ptr().cast(),
3583 to_i64(s.stride, "stride_s")?,
3584 u_ptr.cast(),
3585 ldu,
3586 stride_u,
3587 v_ptr.cast(),
3588 ldv,
3589 stride_v,
3590 workspace.as_mut_ptr().cast(),
3591 to_i32(lwork, "lwork")?,
3592 dev_info.as_mut_ptr().cast(),
3593 residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3594 to_i32(batch_size, "batch_size")?,
3595 ))?;
3596 }
3597 Ok(())
3598}
3599
3600fn validate_gesvd_dims(m: usize, n: usize) -> Result<()> {
3601 if m == 0 || n == 0 || m < n {
3602 return Err(Error::InvalidMatrixShape);
3603 }
3604 Ok(())
3605}
3606
3607fn validate_xgesvdp_inputs<TU, TV>(
3608 m: usize,
3609 n: usize,
3610 a_bytes: usize,
3611 lda: usize,
3612 a_type: DataType,
3613 s_bytes: usize,
3614 s_type: DataType,
3615 jobz: EigenMode,
3616 econ: bool,
3617 u: Option<&(&DeviceMemory<TU>, usize)>,
3618 u_type: DataType,
3619 v: Option<&(&DeviceMemory<TV>, usize)>,
3620 v_type: DataType,
3621) -> Result<()> {
3622 if m == 0 || n == 0 {
3623 return Err(Error::InvalidMatrixShape);
3624 }
3625 validate_x_matrix(m, n, a_bytes, lda, a_type)?;
3626 validate_x_vector(m.min(n), s_bytes, s_type)?;
3627 match jobz {
3628 EigenMode::NoVector => Ok(()),
3629 EigenMode::Vector => {
3630 let Some((u, ldu)) = u else {
3631 return Err(Error::InvalidMatrixShape);
3632 };
3633 let Some((v, ldv)) = v else {
3634 return Err(Error::InvalidMatrixShape);
3635 };
3636 validate_x_eig_output(m, n, u.byte_len(), *ldu, econ, u_type)?;
3637 validate_x_eig_output(n, n, v.byte_len(), *ldv, econ, v_type)
3638 }
3639 }
3640}
3641
3642fn matrix_ref_parts<T>(matrix: Option<MatrixRef<'_, T>>) -> Option<(&DeviceMemory<T>, usize)> {
3643 matrix.map(|matrix| (matrix.data, matrix.leading_dimension))
3644}
3645
3646fn matrix_mut_parts<T>(matrix: Option<MatrixMut<'_, T>>) -> Option<(&mut DeviceMemory<T>, usize)> {
3647 matrix.map(|matrix| (matrix.data, matrix.leading_dimension))
3648}
3649
3650fn matrix_mut_ref_parts<'a, T>(
3651 matrix: Option<&'a MatrixMut<'a, T>>,
3652) -> Option<(&'a DeviceMemory<T>, usize)> {
3653 matrix.map(|matrix| (&*matrix.data, matrix.leading_dimension))
3654}
3655
3656fn matrix_mut_ref_option<'a, T>(matrix: Option<&'a MatrixMut<'a, T>>) -> Option<MatrixRef<'a, T>> {
3657 matrix.map(MatrixMut::as_ref)
3658}
3659
3660fn strided_batched_matrix_ref_parts<T>(
3661 matrix: Option<StridedBatchedMatrixRef<'_, T>>,
3662) -> Option<(&DeviceMemory<T>, usize, usize)> {
3663 matrix.map(|matrix| (matrix.data, matrix.leading_dimension, matrix.stride))
3664}
3665
3666fn strided_batched_matrix_mut_parts<T>(
3667 matrix: Option<StridedBatchedMatrixMut<'_, T>>,
3668) -> Option<(&mut DeviceMemory<T>, usize, usize)> {
3669 matrix.map(|matrix| (matrix.data, matrix.leading_dimension, matrix.stride))
3670}
3671
3672fn strided_batched_matrix_mut_ref_option<'a, T>(
3673 matrix: Option<&'a StridedBatchedMatrixMut<'a, T>>,
3674) -> Option<StridedBatchedMatrixRef<'a, T>> {
3675 matrix.map(StridedBatchedMatrixMut::as_ref)
3676}
3677
3678fn validate_xgesvdr_inputs<TU, TV>(
3679 m: usize,
3680 n: usize,
3681 k: usize,
3682 p: usize,
3683 niters: usize,
3684 a_bytes: usize,
3685 lda: usize,
3686 a_type: DataType,
3687 s_bytes: usize,
3688 s_type: DataType,
3689 job_u: TruncatedSvdMode,
3690 u: Option<&(&DeviceMemory<TU>, usize)>,
3691 u_type: DataType,
3692 job_v: TruncatedSvdMode,
3693 v: Option<&(&DeviceMemory<TV>, usize)>,
3694 v_type: DataType,
3695) -> Result<()> {
3696 if m == 0 || n == 0 || k == 0 || k >= m.min(n) || p == 0 || k.checked_add(p).is_none() {
3697 return Err(Error::InvalidMatrixShape);
3698 }
3699 let kp = k.checked_add(p).ok_or(Error::InvalidMatrixShape)?;
3700 if kp >= m.min(n) || niters == 0 {
3701 return Err(Error::InvalidMatrixShape);
3702 }
3703 validate_x_matrix(m, n, a_bytes, lda, a_type)?;
3704 validate_x_vector(k, s_bytes, s_type)?;
3705 if matches!(job_u, TruncatedSvdMode::Some) {
3706 let Some((u, ldu)) = u else {
3707 return Err(Error::InvalidMatrixShape);
3708 };
3709 validate_x_matrix(m, k, u.byte_len(), *ldu, u_type)?;
3710 }
3711 if matches!(job_v, TruncatedSvdMode::Some) {
3712 let Some((v, ldv)) = v else {
3713 return Err(Error::InvalidMatrixShape);
3714 };
3715 validate_x_matrix(n, k, v.byte_len(), *ldv, v_type)?;
3716 }
3717 Ok(())
3718}
3719
3720fn validate_gesvdj_inputs<T>(
3721 m: usize,
3722 n: usize,
3723 a_len: usize,
3724 lda: usize,
3725 s_len: usize,
3726 jobz: EigenMode,
3727 econ: bool,
3728 u: Option<(&DeviceMemory<T>, usize)>,
3729 v: Option<(&DeviceMemory<T>, usize)>,
3730) -> Result<()> {
3731 if m == 0 || n == 0 {
3732 return Err(Error::InvalidMatrixShape);
3733 }
3734 validate_matrix(m, n, a_len, lda)?;
3735 if s_len < m.min(n) {
3736 return Err(Error::InvalidVectorShape);
3737 }
3738 validate_gesvdj_output(m, n, jobz, econ, u)?;
3739 validate_gesvdj_output(n, n, jobz, econ, v)?;
3740 Ok(())
3741}
3742
3743fn validate_gesvda_strided_batched_inputs<T>(
3744 rank: usize,
3745 m: usize,
3746 n: usize,
3747 a_len: usize,
3748 lda: usize,
3749 stride_a: usize,
3750 s_len: usize,
3751 stride_s: usize,
3752 jobz: EigenMode,
3753 u: Option<(&DeviceMemory<T>, usize, usize)>,
3754 v: Option<(&DeviceMemory<T>, usize, usize)>,
3755 batch_size: usize,
3756) -> Result<()> {
3757 if batch_size == 0 || m == 0 || n == 0 || m < n || rank == 0 || rank > n {
3758 return Err(Error::InvalidMatrixShape);
3759 }
3760
3761 validate_strided_matrix(m, n, a_len, lda, stride_a, batch_size)?;
3762 validate_strided_vector(s_len, n, stride_s, batch_size)?;
3763
3764 match jobz {
3765 EigenMode::NoVector => {}
3766 EigenMode::Vector => {
3767 let Some((u, ldu, stride_u)) = u else {
3768 return Err(Error::InvalidMatrixShape);
3769 };
3770 let Some((v, ldv, stride_v)) = v else {
3771 return Err(Error::InvalidMatrixShape);
3772 };
3773 validate_strided_matrix(m, rank, u.len(), ldu, stride_u, batch_size)?;
3774 validate_strided_matrix(n, rank, v.len(), ldv, stride_v, batch_size)?;
3775 }
3776 }
3777 Ok(())
3778}
3779
3780fn validate_gesvdj_batched_inputs<T>(
3781 m: usize,
3782 n: usize,
3783 a_len: usize,
3784 lda: usize,
3785 s_len: usize,
3786 jobz: EigenMode,
3787 u: Option<(&DeviceMemory<T>, usize)>,
3788 v: Option<(&DeviceMemory<T>, usize)>,
3789 batch_size: usize,
3790) -> Result<()> {
3791 if batch_size == 0 || m == 0 || n == 0 || m > 32 || n > 32 {
3792 return Err(Error::InvalidMatrixShape);
3793 }
3794
3795 let a_cols = n.checked_mul(batch_size).ok_or(Error::InvalidMatrixShape)?;
3796 validate_matrix(m, a_cols, a_len, lda)?;
3797
3798 let s_required = m
3799 .min(n)
3800 .checked_mul(batch_size)
3801 .ok_or(Error::InvalidVectorShape)?;
3802 if s_len < s_required {
3803 return Err(Error::InvalidVectorShape);
3804 }
3805
3806 validate_gesvdj_batched_output(m, n, jobz, u, batch_size)?;
3807 validate_gesvdj_batched_output(n, n, jobz, v, batch_size)?;
3808 Ok(())
3809}
3810
3811fn validate_gesvdj_output<T>(
3812 rows: usize,
3813 cols: usize,
3814 jobz: EigenMode,
3815 econ: bool,
3816 matrix: Option<(&DeviceMemory<T>, usize)>,
3817) -> Result<()> {
3818 match jobz {
3819 EigenMode::NoVector => Ok(()),
3820 EigenMode::Vector => {
3821 let Some((matrix, ld)) = matrix else {
3822 return Err(Error::InvalidMatrixShape);
3823 };
3824 let out_cols = if econ { rows.min(cols) } else { cols };
3825 validate_matrix(rows, out_cols, matrix.len(), ld)
3826 }
3827 }
3828}
3829
3830fn validate_gesvdj_batched_output<T>(
3831 rows: usize,
3832 cols: usize,
3833 jobz: EigenMode,
3834 matrix: Option<(&DeviceMemory<T>, usize)>,
3835 batch_size: usize,
3836) -> Result<()> {
3837 match jobz {
3838 EigenMode::NoVector => Ok(()),
3839 EigenMode::Vector => {
3840 let Some((matrix, ld)) = matrix else {
3841 return Err(Error::InvalidMatrixShape);
3842 };
3843 let out_cols = rows
3844 .min(cols)
3845 .checked_mul(batch_size)
3846 .ok_or(Error::InvalidMatrixShape)?;
3847 validate_matrix(rows, out_cols, matrix.len(), ld)
3848 }
3849 }
3850}
3851
3852fn optional_gesvda_output_ptr<T>(
3853 matrix: Option<(&DeviceMemory<T>, usize, usize)>,
3854 rows: usize,
3855 cols: usize,
3856 jobz: EigenMode,
3857) -> Result<(*mut T, i32, i64)> {
3858 match jobz {
3859 EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
3860 EigenMode::Vector => {
3861 let Some((matrix, ld, stride)) = matrix else {
3862 return Err(Error::InvalidMatrixShape);
3863 };
3864 validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
3865 Ok((
3866 matrix.as_ptr() as *mut T,
3867 to_i32(ld, "ld")?,
3868 to_i64(stride, "stride")?,
3869 ))
3870 }
3871 }
3872}
3873
3874fn optional_gesvda_output_mut_ptr<T>(
3875 matrix: Option<(&mut DeviceMemory<T>, usize, usize)>,
3876 rows: usize,
3877 cols: usize,
3878 jobz: EigenMode,
3879) -> Result<(*mut T, i32, i64)> {
3880 match jobz {
3881 EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
3882 EigenMode::Vector => {
3883 let Some((matrix, ld, stride)) = matrix else {
3884 return Err(Error::InvalidMatrixShape);
3885 };
3886 validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
3887 Ok((
3888 matrix.as_mut_ptr().cast(),
3889 to_i32(ld, "ld")?,
3890 to_i64(stride, "stride")?,
3891 ))
3892 }
3893 }
3894}
3895
3896fn validate_gesvd_inputs<T>(
3897 m: usize,
3898 n: usize,
3899 a_len: usize,
3900 lda: usize,
3901 s_len: usize,
3902 job_u: SvdMode,
3903 u: Option<&(&DeviceMemory<T>, usize)>,
3904 job_vt: SvdMode,
3905 vt: Option<&(&DeviceMemory<T>, usize)>,
3906) -> Result<()> {
3907 validate_gesvd_dims(m, n)?;
3908 validate_matrix(m, n, a_len, lda)?;
3909 if s_len < n {
3910 return Err(Error::InvalidVectorShape);
3911 }
3912 validate_svd_output(m, m, job_u, u)?;
3913 validate_svd_output(n, n, job_vt, vt)?;
3914 Ok(())
3915}
3916
3917fn validate_x_svd_output<T>(
3918 rows: usize,
3919 full_cols: usize,
3920 matrix: Option<(&DeviceMemory<T>, usize)>,
3921 mode: SvdMode,
3922 data_type: DataType,
3923) -> Result<()> {
3924 match mode {
3925 SvdMode::None | SvdMode::Overwrite => Ok(()),
3926 SvdMode::All => {
3927 let Some((matrix, ld)) = matrix else {
3928 return Err(Error::InvalidMatrixShape);
3929 };
3930 validate_x_matrix(rows, full_cols, matrix.byte_len(), ld, data_type)
3931 }
3932 SvdMode::Some => {
3933 let Some((matrix, ld)) = matrix else {
3934 return Err(Error::InvalidMatrixShape);
3935 };
3936 validate_x_matrix(rows, full_cols.min(rows), matrix.byte_len(), ld, data_type)
3937 }
3938 }
3939}
3940
3941fn validate_svd_output<T>(
3942 rows: usize,
3943 full_cols: usize,
3944 mode: SvdMode,
3945 matrix: Option<&(&DeviceMemory<T>, usize)>,
3946) -> Result<()> {
3947 match mode {
3948 SvdMode::None | SvdMode::Overwrite => Ok(()),
3949 SvdMode::All => {
3950 let Some((matrix, ld)) = matrix else {
3951 return Err(Error::InvalidMatrixShape);
3952 };
3953 validate_matrix(rows, full_cols, matrix.len(), *ld)
3954 }
3955 SvdMode::Some => {
3956 let Some((matrix, ld)) = matrix else {
3957 return Err(Error::InvalidMatrixShape);
3958 };
3959 validate_matrix(rows, full_cols.min(rows), matrix.len(), *ld)
3960 }
3961 }
3962}
3963
3964fn validate_x_eig_output(
3965 rows: usize,
3966 cols: usize,
3967 bytes: usize,
3968 ld: usize,
3969 econ: bool,
3970 data_type: DataType,
3971) -> Result<()> {
3972 let out_cols = if econ { rows.min(cols) } else { cols };
3973 validate_x_matrix(rows, out_cols, bytes, ld, data_type)
3974}
3975
3976fn optional_matrix_ptr<T>(
3977 matrix: Option<(&mut DeviceMemory<T>, usize)>,
3978 order: usize,
3979 mode: SvdMode,
3980) -> Result<(*mut T, i32)> {
3981 match mode {
3982 SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), to_i32(order.max(1), "ld")?)),
3983 SvdMode::All | SvdMode::Some => {
3984 let Some((matrix, ld)) = matrix else {
3985 return Err(Error::InvalidMatrixShape);
3986 };
3987 Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
3988 }
3989 }
3990}
3991
3992fn optional_x_matrix_ptr<T>(
3993 matrix: Option<(&DeviceMemory<T>, usize)>,
3994 rows: usize,
3995 cols: usize,
3996 mode: SvdMode,
3997 data_type: DataType,
3998) -> Result<(*mut T, i64)> {
3999 match mode {
4000 SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), 1)),
4001 SvdMode::All => {
4002 let Some((matrix, ld)) = matrix else {
4003 return Err(Error::InvalidMatrixShape);
4004 };
4005 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4006 Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4007 }
4008 SvdMode::Some => {
4009 let Some((matrix, ld)) = matrix else {
4010 return Err(Error::InvalidMatrixShape);
4011 };
4012 validate_x_matrix(rows, cols.min(rows), matrix.byte_len(), ld, data_type)?;
4013 Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4014 }
4015 }
4016}
4017
4018fn optional_x_matrix_mut_ptr<T>(
4019 matrix: Option<(&mut DeviceMemory<T>, usize)>,
4020 rows: usize,
4021 cols: usize,
4022 mode: SvdMode,
4023 data_type: DataType,
4024) -> Result<(*mut T, i64)> {
4025 match mode {
4026 SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), 1)),
4027 SvdMode::All => {
4028 let Some((matrix, ld)) = matrix else {
4029 return Err(Error::InvalidMatrixShape);
4030 };
4031 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4032 Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4033 }
4034 SvdMode::Some => {
4035 let Some((matrix, ld)) = matrix else {
4036 return Err(Error::InvalidMatrixShape);
4037 };
4038 validate_x_matrix(rows, cols.min(rows), matrix.byte_len(), ld, data_type)?;
4039 Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4040 }
4041 }
4042}
4043
4044fn optional_x_eig_matrix_ptr<T>(
4045 matrix: Option<(&DeviceMemory<T>, usize)>,
4046 rows: usize,
4047 cols: usize,
4048 jobz: EigenMode,
4049 econ: bool,
4050 data_type: DataType,
4051) -> Result<(*mut T, i64)> {
4052 match jobz {
4053 EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4054 EigenMode::Vector => {
4055 let Some((matrix, ld)) = matrix else {
4056 return Err(Error::InvalidMatrixShape);
4057 };
4058 validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
4059 Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4060 }
4061 }
4062}
4063
4064fn optional_x_eig_matrix_mut_ptr<T>(
4065 matrix: Option<(&mut DeviceMemory<T>, usize)>,
4066 rows: usize,
4067 cols: usize,
4068 jobz: EigenMode,
4069 econ: bool,
4070 data_type: DataType,
4071) -> Result<(*mut T, i64)> {
4072 match jobz {
4073 EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4074 EigenMode::Vector => {
4075 let Some((matrix, ld)) = matrix else {
4076 return Err(Error::InvalidMatrixShape);
4077 };
4078 validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
4079 Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4080 }
4081 }
4082}
4083
4084fn optional_x_truncated_u_ptr<T>(
4085 matrix: Option<(&DeviceMemory<T>, usize)>,
4086 rows: usize,
4087 cols: usize,
4088 mode: TruncatedSvdMode,
4089 data_type: DataType,
4090) -> Result<(*mut T, i64)> {
4091 match mode {
4092 TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4093 TruncatedSvdMode::Some => {
4094 let Some((matrix, ld)) = matrix else {
4095 return Err(Error::InvalidMatrixShape);
4096 };
4097 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4098 Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4099 }
4100 }
4101}
4102
4103fn optional_x_truncated_u_mut_ptr<T>(
4104 matrix: Option<(&mut DeviceMemory<T>, usize)>,
4105 rows: usize,
4106 cols: usize,
4107 mode: TruncatedSvdMode,
4108 data_type: DataType,
4109) -> Result<(*mut T, i64)> {
4110 match mode {
4111 TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4112 TruncatedSvdMode::Some => {
4113 let Some((matrix, ld)) = matrix else {
4114 return Err(Error::InvalidMatrixShape);
4115 };
4116 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4117 Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4118 }
4119 }
4120}
4121
4122fn optional_x_truncated_v_ptr<T>(
4123 matrix: Option<(&DeviceMemory<T>, usize)>,
4124 rows: usize,
4125 cols: usize,
4126 mode: TruncatedSvdMode,
4127 data_type: DataType,
4128) -> Result<(*mut T, i64)> {
4129 match mode {
4130 TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4131 TruncatedSvdMode::Some => {
4132 let Some((matrix, ld)) = matrix else {
4133 return Err(Error::InvalidMatrixShape);
4134 };
4135 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4136 Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4137 }
4138 }
4139}
4140
4141fn optional_x_truncated_v_mut_ptr<T>(
4142 matrix: Option<(&mut DeviceMemory<T>, usize)>,
4143 rows: usize,
4144 cols: usize,
4145 mode: TruncatedSvdMode,
4146 data_type: DataType,
4147) -> Result<(*mut T, i64)> {
4148 match mode {
4149 TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4150 TruncatedSvdMode::Some => {
4151 let Some((matrix, ld)) = matrix else {
4152 return Err(Error::InvalidMatrixShape);
4153 };
4154 validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4155 Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4156 }
4157 }
4158}
4159
4160fn optional_gesvdj_matrix_ptr<T>(
4161 matrix: Option<(&DeviceMemory<T>, usize)>,
4162 rows: usize,
4163 cols: usize,
4164 jobz: EigenMode,
4165 econ: bool,
4166) -> Result<(*mut T, i32)> {
4167 match jobz {
4168 EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4169 EigenMode::Vector => {
4170 let Some((matrix, ld)) = matrix else {
4171 return Err(Error::InvalidMatrixShape);
4172 };
4173 let out_cols = if econ { rows.min(cols) } else { cols };
4174 validate_matrix(rows, out_cols, matrix.len(), ld)?;
4175 Ok((matrix.as_ptr() as *mut T, to_i32(ld, "ld")?))
4176 }
4177 }
4178}
4179
4180fn optional_gesvdj_matrix_mut_ptr<T>(
4181 matrix: Option<(&mut DeviceMemory<T>, usize)>,
4182 rows: usize,
4183 cols: usize,
4184 jobz: EigenMode,
4185 econ: bool,
4186) -> Result<(*mut T, i32)> {
4187 match jobz {
4188 EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4189 EigenMode::Vector => {
4190 let Some((matrix, ld)) = matrix else {
4191 return Err(Error::InvalidMatrixShape);
4192 };
4193 let out_cols = if econ { rows.min(cols) } else { cols };
4194 validate_matrix(rows, out_cols, matrix.len(), ld)?;
4195 Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
4196 }
4197 }
4198}
4199
4200fn require_rwork_buffer<T>(rwork: Option<&DeviceMemory<T>>, m: usize, n: usize) -> Result<()> {
4201 let required = n.saturating_sub(1).min(m.saturating_sub(1));
4202 if let Some(rwork) = rwork
4203 && rwork.len() < required
4204 {
4205 return Err(Error::InvalidVectorShape);
4206 }
4207 Ok(())
4208}
4209
4210fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
4211 if rows == 0 || cols == 0 {
4212 return Err(Error::InvalidMatrixShape);
4213 }
4214 if lda < rows {
4215 return Err(Error::InvalidLeadingDimension);
4216 }
4217 let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
4218 if len < required {
4219 return Err(Error::InvalidMatrixShape);
4220 }
4221 Ok(())
4222}
4223
4224fn validate_x_matrix(
4225 rows: usize,
4226 cols: usize,
4227 bytes: usize,
4228 lda: usize,
4229 data_type: DataType,
4230) -> Result<()> {
4231 if rows == 0 || cols == 0 {
4232 return Err(Error::InvalidMatrixShape);
4233 }
4234 if lda < rows {
4235 return Err(Error::InvalidLeadingDimension);
4236 }
4237 let elem_size = data_type.size_in_bytes();
4238 let required = lda
4239 .checked_mul(cols)
4240 .and_then(|count| count.checked_mul(elem_size))
4241 .ok_or(Error::InvalidMatrixShape)?;
4242 if bytes < required {
4243 return Err(Error::InvalidMatrixShape);
4244 }
4245 Ok(())
4246}
4247
4248fn validate_x_vector(len: usize, bytes: usize, data_type: DataType) -> Result<()> {
4249 let required = len
4250 .checked_mul(data_type.size_in_bytes())
4251 .ok_or(Error::InvalidVectorShape)?;
4252 if bytes < required {
4253 return Err(Error::InvalidVectorShape);
4254 }
4255 Ok(())
4256}
4257
4258fn validate_strided_matrix(
4259 rows: usize,
4260 cols: usize,
4261 len: usize,
4262 lda: usize,
4263 stride: usize,
4264 batch_size: usize,
4265) -> Result<()> {
4266 validate_matrix(rows, cols, len, lda)?;
4267 if batch_size == 0 {
4268 return Err(Error::InvalidMatrixShape);
4269 }
4270 let footprint = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
4271 if stride < footprint {
4272 return Err(Error::InvalidMatrixShape);
4273 }
4274 let required = if batch_size == 1 {
4275 footprint
4276 } else {
4277 stride
4278 .checked_mul(batch_size - 1)
4279 .and_then(|base| base.checked_add(footprint))
4280 .ok_or(Error::InvalidMatrixShape)?
4281 };
4282 if len < required {
4283 return Err(Error::InvalidMatrixShape);
4284 }
4285 Ok(())
4286}
4287
4288fn validate_strided_vector(
4289 len: usize,
4290 width: usize,
4291 stride: usize,
4292 batch_size: usize,
4293) -> Result<()> {
4294 if width == 0 || batch_size == 0 {
4295 return Err(Error::InvalidVectorShape);
4296 }
4297 if stride < width {
4298 return Err(Error::InvalidVectorShape);
4299 }
4300 let required = if batch_size == 1 {
4301 width
4302 } else {
4303 stride
4304 .checked_mul(batch_size - 1)
4305 .and_then(|base| base.checked_add(width))
4306 .ok_or(Error::InvalidVectorShape)?
4307 };
4308 if len < required {
4309 return Err(Error::InvalidVectorShape);
4310 }
4311 Ok(())
4312}
4313
4314fn require_workspace(actual: usize, required: usize) -> Result<()> {
4315 if actual < required {
4316 return Err(Error::InsufficientWorkspaceSize { required, actual });
4317 }
4318 Ok(())
4319}
4320
4321fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
4322 if actual < required {
4323 return Err(Error::InsufficientWorkspaceSize { required, actual });
4324 }
4325 Ok(())
4326}
4327
4328fn require_host_workspace(actual: usize, required: usize) -> Result<()> {
4329 if actual < required {
4330 return Err(Error::InsufficientWorkspaceSize { required, actual });
4331 }
4332 Ok(())
4333}
4334
4335fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
4336 if dev_info.is_empty() {
4337 return Err(Error::InvalidVectorShape);
4338 }
4339 Ok(())
4340}
4341
4342fn require_info_buffer_len(dev_info: &DeviceMemory<i32>, required: usize) -> Result<()> {
4343 if dev_info.len() < required {
4344 return Err(Error::InvalidVectorShape);
4345 }
4346 Ok(())
4347}
4348
#[cfg(all(test, feature = "testing"))]
mod tests {
    use singe_core::assert_close;
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::{params::Params, testing::setup_context_if_available};

    /// `sgesvd` on the column-major matrix diag(3, 2) must report singular
    /// values [3, 2] and a zero `dev_info`.
    #[test]
    fn test_sgesvd_returns_expected_singular_values() -> Result<()> {
        // Skip when no context can be set up (e.g. no GPU available).
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };

        let mut a = DeviceMemory::from_slice(&[3.0_f32, 0.0, 0.0, 2.0])?;
        let mut s = DeviceMemory::create(2)?;
        let mut workspace = DeviceMemory::create(sgesvd_buffer_size(&ctx, 2, 2)?)?;
        let mut dev_info = DeviceMemory::create(1)?;

        sgesvd(
            &ctx,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut a, 2),
            &mut s,
            None,
            None,
            &mut workspace,
            None,
            &mut dev_info,
        )?;

        let singular_values = s.copy_to_host_vec()?;
        let info = dev_info.copy_to_host_vec()?;

        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }

    /// The generic `xgesvd` API must produce the same result as `sgesvd` on
    /// diag(3, 2).
    #[test]
    fn test_xgesvd_returns_expected_singular_values() -> Result<()> {
        // Skip when no context can be set up (e.g. no GPU available).
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let params = Params::create()?;

        let mut a = DeviceMemory::from_slice(&[3.0_f32, 0.0, 0.0, 2.0])?;
        let mut s = DeviceMemory::create(2)?;
        let workspace_sizes = xgesvd_buffer_size::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixRef::new(&a, 2),
            &s,
            None,
            None,
        )?;
        // Allocate at least one byte so a zero-size query still yields valid buffers.
        let mut device_workspace = DeviceMemory::create(workspace_sizes.device_bytes.max(1))?;
        let mut host_workspace = vec![0_u8; workspace_sizes.host_bytes.max(1)];
        let mut dev_info = DeviceMemory::create(1)?;

        xgesvd::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut a, 2),
            &mut s,
            None,
            None,
            ByteWorkspaceMut::new(&mut device_workspace, &mut host_workspace),
            &mut dev_info,
        )?;

        let singular_values = s.copy_to_host_vec()?;
        let info = dev_info.copy_to_host_vec()?;

        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }
}
4444}