1#![allow(deprecated)]
2#![allow(clippy::new_without_default)]
3#![allow(clippy::needless_return)]
4#![allow(clippy::manual_slice_size_calculation)]
5#![allow(clippy::unwrap_or_default)]
6#![allow(clippy::single_char_add_str)]
7#![allow(clippy::needless_borrow)]
8#![allow(clippy::manual_is_multiple_of)]
9#![allow(clippy::extend_with_drain)]
10#![allow(clippy::vec_init_then_push)]
11#![allow(clippy::match_like_matches_macro)]
12#![allow(clippy::manual_clamp)]
13#![allow(clippy::for_kv_map)]
14#![allow(clippy::derivable_impls)]
15pub mod error;
222pub use error::{LinalgError, LinalgResult};
223
224pub mod attention;
226mod basic;
227pub mod batch;
228pub mod broadcast;
229pub mod complex;
230pub mod convolution;
231mod decomposition;
232pub mod decomposition_advanced;
233pub mod eigen;
235pub use self::eigen::{
236 advanced_precision_eig, eig, eig_gen, eigh, eigh_gen, eigvals, eigvals_gen, eigvalsh,
237 eigvalsh_gen, power_iteration,
238};
239
240pub mod eigen_specialized;
242pub mod extended_precision;
243pub mod generic;
244pub mod gradient;
245pub mod hierarchical;
246mod iterative_solvers;
247pub mod kronecker;
248pub mod large_scale;
249pub mod lowrank;
250pub mod matrix_calculus;
251pub mod matrix_dynamics;
252pub mod matrix_equations;
253pub mod matrix_factorization;
254pub mod matrix_functions;
255pub mod matrixfree;
256pub mod mixed_precision;
257mod norm;
258pub mod optim;
259pub mod parallel;
260pub mod parallel_dispatch;
261pub mod perf_opt;
262pub mod preconditioners;
263pub mod projection;
264pub mod quantization;
266pub use self::quantization::calibration::{
268 calibrate_matrix, calibrate_vector, get_activation_calibration_config,
269 get_weight_calibration_config, CalibrationConfig, CalibrationMethod,
270};
271pub mod random;
272pub mod random_matrices;
273pub mod circulant_toeplitz;
276mod diagnostics;
277pub mod fft;
278pub mod scalable;
279pub mod simd_ops;
280mod solve;
281pub mod solvers;
282pub mod sparse_dense;
283pub mod special;
284pub mod specialized;
285pub mod stats;
286pub mod structured;
287#[cfg(feature = "tensor_contraction")]
288pub mod tensor_contraction;
289pub mod tensor_train;
290mod validation;
291#[cfg(any(
296 feature = "cuda",
297 feature = "opencl",
298 feature = "rocm",
299 feature = "metal"
300))]
301pub mod gpu;
302
303#[cfg(feature = "autograd")]
305pub mod autograd;
306
307pub mod compat;
309pub mod compat_wrappers;
310
311pub mod blas_accelerated;
313pub mod lapack_accelerated;
314
315pub mod blas;
317pub mod lapack;
318
319pub mod accelerated {
321 pub use super::blas_accelerated::*;
329 pub use super::lapack_accelerated::*;
330}
331
332pub use self::basic::{det, inv, matrix_power, trace as basic_trace};
334pub use self::eigen_specialized::{
335 banded_eigen, banded_eigh, banded_eigvalsh, circulant_eigenvalues, largest_k_eigh,
336 partial_eigen, smallest_k_eigh, tridiagonal_eigen, tridiagonal_eigh, tridiagonal_eigvalsh,
337};
338pub use self::complex::enhanced_ops::{
340 det as complex_det, frobenius_norm, hermitian_part, inner_product, is_hermitian, is_unitary,
341 matrix_exp, matvec, polar_decomposition, power_method, rank as complex_rank,
342 schur as complex_schur, skew_hermitian_part, trace,
343};
344pub use self::complex::{complex_inverse, complex_matmul, hermitian_transpose};
345pub use self::decomposition::{cholesky, lu, qr, schur, svd};
347pub use self::decomposition::{cholesky_default, lu_default, qr_default, svd_default};
349pub use self::decomposition_advanced::{
351 jacobi_svd, polar_decomposition as advanced_polar_decomposition, polar_decomposition_newton,
352 qr_with_column_pivoting,
353};
354pub use self::basic::{det_default, inv_default, matrix_power_default};
356pub use self::iterative_solvers::conjugate_gradient_default;
358pub use self::extended_precision::*;
360pub use self::iterative_solvers::*;
361pub use self::matrix_equations::{
363 solve_continuous_riccati, solve_discrete_riccati, solve_generalized_sylvester, solve_stein,
364 solve_sylvester,
365};
366pub use self::matrix_factorization::{
367 cur_decomposition, interpolative_decomposition, nmf, rank_revealing_qr, utv_decomposition,
368};
369pub use self::matrix_functions::{
370 acosm, asinm, atanm, coshm, cosm, expm, geometric_mean_spd, logm, logm_parallel, nuclear_norm,
371 signm, sinhm, sinm, spectral_condition_number, spectral_radius, sqrtm, sqrtm_parallel, tanhm,
372 tanm, tikhonov_regularization,
373};
374pub use self::matrixfree::{
375 block_diagonal_operator, conjugate_gradient as matrix_free_conjugate_gradient,
376 diagonal_operator, gmres as matrix_free_gmres, jacobi_preconditioner,
377 preconditioned_conjugate_gradient as matrix_free_preconditioned_conjugate_gradient,
378 LinearOperator, MatrixFreeOp,
379};
380pub use self::norm::*;
381pub use self::solve::{lstsq, solve, solve_multiple, solve_triangular, LstsqResult};
383pub use self::solve::{lstsq_default, solve_default, solve_multiple_default};
385pub use self::solvers::iterative::{
387 bicgstab, conjugate_gradient as cg_solver, gmres,
388 preconditioned_conjugate_gradient as pcg_solver, IterativeSolverOptions, IterativeSolverResult,
389};
390pub use self::specialized::{
391 specialized_to_operator, BandedMatrix, SpecializedMatrix, SymmetricMatrix, TridiagonalMatrix,
392};
393pub use self::stats::*;
394pub use self::structured::{
395 structured_to_operator, CirculantMatrix, HankelMatrix, StructuredMatrix, ToeplitzMatrix,
396};
397#[cfg(feature = "tensor_contraction")]
398pub use self::tensor_contraction::{batch_matmul, contract, einsum, hosvd};
399
400pub mod prelude {
402 pub use super::attention::{
410 attention, attention_with_alibi, attention_with_rpe, causal_attention, flash_attention,
411 grouped_query_attention, linear_attention, masked_attention, multi_head_attention,
412 relative_position_attention, rotary_embedding, scaled_dot_product_attention,
413 sparse_attention, AttentionConfig, AttentionMask,
414 };
415 pub use super::basic::{det, inv};
416 pub use super::batch::attention::{
417 batch_flash_attention, batch_multi_head_attention, batch_multi_query_attention,
418 };
419 pub use super::broadcast::{
420 broadcast_matmul, broadcast_matmul_3d, broadcast_matvec, BroadcastExt,
421 };
422 pub use super::complex::enhanced_ops::{
423 det as complex_det, frobenius_norm as complex_frobenius_norm, hermitian_part,
424 inner_product as complex_inner_product, is_hermitian, is_unitary,
425 matrix_exp as complex_exp, matvec as complex_matvec, polar_decomposition as complex_polar,
426 schur as complex_schur, skew_hermitian_part,
427 };
428 pub use super::complex::{
429 complex_inverse, complex_matmul, complex_norm_frobenius, hermitian_transpose,
430 };
431 pub use super::convolution::{
432 col2im, compute_conv_indices, conv2d_backward_bias, conv2d_backward_input,
433 conv2d_backward_kernel, conv2d_im2col, conv_transpose2d, im2col, max_pool2d,
434 max_pool2d_backward,
435 };
436 pub use super::decomposition::{cholesky, lu, qr, schur, svd};
437 pub use super::decomposition_advanced::{
438 jacobi_svd, polar_decomposition as advanced_polar_decomposition,
439 polar_decomposition_newton, qr_with_column_pivoting,
440 };
441 pub use super::eigen::{
442 advanced_precision_eig, eig, eig_gen, eigh, eigh_gen, eigvals, eigvals_gen, eigvalsh,
443 eigvalsh_gen, power_iteration,
444 };
445 pub use super::eigen_specialized::{
446 banded_eigen, banded_eigh, banded_eigvalsh, circulant_eigenvalues, largest_k_eigh,
447 partial_eigen, smallest_k_eigh, tridiagonal_eigen, tridiagonal_eigh, tridiagonal_eigvalsh,
448 };
449 pub use super::extended_precision::eigen::{
450 extended_eig, extended_eigh, extended_eigvals, extended_eigvalsh,
451 };
452 pub use super::extended_precision::factorizations::{
453 extended_cholesky, extended_lu, extended_qr, extended_svd,
454 };
455 pub use super::extended_precision::{
456 extended_det, extended_matmul, extended_matvec, extended_solve,
457 };
458 pub use super::hierarchical::{
459 adaptive_block_lowrank, build_cluster_tree, BlockType, ClusterNode, HMatrix, HMatrixBlock,
460 HMatrixMemoryInfo, HSSMatrix, HSSNode,
461 };
462 pub use super::iterative_solvers::{
463 bicgstab, conjugate_gradient, gauss_seidel, geometric_multigrid, jacobi_method, minres,
464 successive_over_relaxation,
465 };
466 pub use super::kronecker::{
467 advanced_kfac_step, kfac_factorization, kfac_update, kron, kron_factorize, kron_matmul,
468 kron_matvec, BlockDiagonalFisher, BlockFisherMemoryInfo, KFACOptimizer,
469 };
470 pub use super::large_scale::{
471 block_krylov_solve, ca_gmres, incremental_svd, randomized_block_lanczos,
472 randomized_least_squares, randomized_norm,
473 };
474 pub use super::lowrank::{
475 cur_decomposition, nmf as lowrank_nmf, pca, randomized_svd, truncated_svd,
476 };
477 pub use super::solvers::iterative::{
478 bicgstab as iterative_bicgstab, conjugate_gradient as iterative_cg,
479 gmres as iterative_gmres, preconditioned_conjugate_gradient as iterative_pcg,
480 IterativeSolverOptions, IterativeSolverResult,
481 };
482 pub use super::matrix_dynamics::{
489 lyapunov_solve, matrix_exp_action, matrix_ode_solve, quantum_evolution, riccati_solve,
490 stability_analysis, DynamicsConfig, ODEResult,
491 };
492 pub use super::matrix_factorization::{
493 interpolative_decomposition, nmf, rank_revealing_qr, utv_decomposition,
494 };
495 pub use super::matrix_functions::{
496 acosm, asinm, atanm, coshm, cosm, expm, geometric_mean_spd, logm, logm_parallel,
497 matrix_power, nuclear_norm, polar_decomposition, signm, sinhm, sinm,
498 spectral_condition_number, spectral_radius, sqrtm, sqrtm_parallel, tanhm, tanm,
499 tikhonov_regularization,
500 };
501 pub use super::matrixfree::{
502 block_diagonal_operator, conjugate_gradient as matrix_free_conjugate_gradient,
503 diagonal_operator, gmres as matrix_free_gmres, jacobi_preconditioner,
504 preconditioned_conjugate_gradient as matrix_free_preconditioned_conjugate_gradient,
505 LinearOperator, MatrixFreeOp,
506 };
507 pub use super::mixed_precision::{
509 convert, convert_2d, iterative_refinement_solve, mixed_precision_cond,
510 mixed_precision_dot_f32, mixed_precision_matmul, mixed_precision_matvec,
511 mixed_precision_qr, mixed_precision_solve, mixed_precision_svd,
512 };
513 pub use super::norm::{cond, matrix_norm, matrix_rank, vector_norm, vector_norm_parallel};
519 pub use super::optim::{block_matmul, strassen_matmul, tiled_matmul};
520 pub use super::perf_opt::{
521 blocked_matmul, inplace_add, inplace_scale, matmul_benchmark, optimized_transpose,
522 OptAlgorithm, OptConfig,
523 };
524 pub use super::preconditioners::{
525 analyze_preconditioner, create_preconditioner, preconditioned_conjugate_gradient,
526 preconditioned_gmres, AdaptivePreconditioner, BlockJacobiPreconditioner,
527 DiagonalPreconditioner, IncompleteCholeskyPreconditioner, IncompleteLUPreconditioner,
528 PolynomialPreconditioner, PreconditionerAnalysis, PreconditionerConfig, PreconditionerOp,
529 PreconditionerType,
530 };
531 pub use super::projection::{
532 gaussian_randommatrix, johnson_lindenstrauss_min_dim, johnson_lindenstrauss_transform,
533 project, sparse_randommatrix, very_sparse_randommatrix,
534 };
535 pub use super::quantization::calibration::{
536 calibrate_matrix, calibrate_vector, CalibrationConfig, CalibrationMethod,
537 };
538 #[cfg(feature = "simd")]
539 pub use super::quantization::simd::{
540 simd_quantized_dot, simd_quantized_matmul, simd_quantized_matvec,
541 };
542 pub use super::quantization::{
543 dequantize_matrix, dequantize_vector, fake_quantize, fake_quantize_vector, quantize_matrix,
544 quantize_matrix_per_channel, quantize_vector, quantized_dot, quantized_matmul,
545 quantized_matvec, QuantizationMethod, QuantizationParams, QuantizedDataType,
546 QuantizedMatrix, QuantizedVector,
547 };
548 pub use super::random::{
549 banded, diagonal, hilbert, low_rank, normal, orthogonal, permutation, random_correlation,
550 sparse, spd, toeplitz, uniform, vandermonde, with_condition_number, with_eigenvalues,
551 };
552 pub use super::random_matrices::{
553 random_complexmatrix, random_hermitian, randommatrix, Distribution1D, MatrixType,
554 };
555 pub use super::fft::{
562 apply_window, dct_1d, dst_1d, fft_1d, fft_2d, fft_3d, fft_convolve, fft_frequencies,
563 idct_1d, irfft_1d, periodogram_psd, rfft_1d, welch_psd, Complex32, Complex64, FFTAlgorithm,
564 FFTPlan, WindowFunction,
565 };
566 pub use super::generic::{
567 gdet, geig, gemm, gemv, ginv, gnorm, gqr, gsolve, gsvd, GenericEigen, GenericQR,
568 GenericSVD, LinalgScalar, PrecisionSelector,
569 };
570 pub use super::scalable::{
571 adaptive_decomposition, blocked_matmul as scalable_blocked_matmul, classify_aspect_ratio,
572 lq_decomposition, randomized_svd as scalable_randomized_svd, tsqr, AdaptiveResult,
573 AspectRatio, ScalableConfig,
574 };
575 #[cfg(feature = "simd")]
576 pub use super::simd_ops::{
577 simd_axpy_f32,
578 simd_axpy_f64,
579 simd_dot_f32,
580 simd_dot_f64,
581 simd_frobenius_norm_f32,
582 simd_frobenius_norm_f64,
583 simd_gemm_f32,
585 simd_gemm_f64,
586 simd_gemv_f32,
587 simd_gemv_f64,
588 simd_matmul_f32,
589 simd_matmul_f64,
590 simd_matmul_optimized_f32,
591 simd_matmul_optimized_f64,
592 simd_matvec_f32,
593 simd_matvec_f64,
594 simd_transpose_f32,
596 simd_transpose_f64,
597 simd_vector_norm_f32,
599 simd_vector_norm_f64,
600 simdmatrix_max_f32,
601 simdmatrix_max_f64,
602 simdmatrix_min_f32,
603 simdmatrix_min_f64,
604 GemmBlockSizes,
605 };
606 pub use super::solve::{lstsq, solve, solve_multiple, solve_triangular};
607 pub use super::sparse_dense::{
608 dense_sparse_matmul, dense_sparse_matvec, sparse_dense_add, sparse_dense_elementwise_mul,
609 sparse_dense_matmul, sparse_dense_matvec, sparse_dense_sub, sparse_from_ndarray,
610 SparseMatrixView,
611 };
612 pub use super::special::block_diag;
613 pub use super::specialized::{
614 specialized_to_operator, BandedMatrix, BlockTridiagonalMatrix, SpecializedMatrix,
615 SymmetricMatrix, TridiagonalMatrix,
616 };
617 pub use super::stats::{correlationmatrix, covariancematrix};
618 pub use super::structured::{
619 solve_circulant, solve_toeplitz, structured_to_operator, CirculantMatrix, HankelMatrix,
620 StructuredMatrix, ToeplitzMatrix,
621 };
622 #[cfg(feature = "tensor_contraction")]
623 pub use super::tensor_contraction::{batch_matmul, contract, einsum, hosvd};
624 pub use super::tensor_train::{tt_add, tt_decomposition, tt_hadamard, TTTensor};
625
626 #[cfg(feature = "autograd")]
635 pub mod autograd {
636 pub use super::super::autograd::*;
644 }
645
646 pub mod accelerated {
648 pub use super::super::blas_accelerated::{
650 dot, gemm, gemv, inv as fast_inv, matmul, norm, solve as fast_solve,
651 };
652 pub use super::super::lapack_accelerated::{
653 cholesky as fast_cholesky, eig as fast_eig, eigh as fast_eigh, lu as fast_lu,
654 qr as fast_qr, svd as fast_svd,
655 };
656 }
657
658 pub mod scipy_compat {
660 pub use super::super::compat::{
681 block_diag,
683 cholesky,
684 compat_solve as solve,
686 cond,
687 cosm,
688 det,
690 eig,
692 eig_banded,
693 eigh,
694 eigh_tridiagonal,
695 eigvals,
696 eigvals_banded,
697 eigvalsh,
698 eigvalsh_tridiagonal,
699 expm,
701 fractionalmatrix_power,
702 funm,
703 inv,
704 logm,
705 lstsq,
706 lu,
708 matrix_rank,
709 norm,
710 pinv,
711 polar,
712 qr,
713 rq,
714 schur,
715 sinm,
716 solve_banded,
717 solve_triangular,
718 sqrtm,
719 svd,
720 tanm,
721 vector_norm,
722 SvdResult,
724 };
725 }
726}