1use super::lanczos::{EigenResult, LanczosOptions};
7use super::symmetric;
8use crate::error::{SparseError, SparseResult};
9use crate::sym_csr::SymCsrMatrix;
10use scirs2_core::numeric::Float;
11use std::fmt::Debug;
12use std::ops::{Add, Div, Mul, Sub};
13
14#[allow(dead_code)]
54pub fn eigsh_generalized<T>(
55 a_matrix: &SymCsrMatrix<T>,
56 b_matrix: &SymCsrMatrix<T>,
57 k: Option<usize>,
58 which: Option<&str>,
59 options: Option<LanczosOptions>,
60) -> SparseResult<EigenResult<T>>
61where
62 T: Float
63 + Debug
64 + Copy
65 + Add<Output = T>
66 + Sub<Output = T>
67 + Mul<Output = T>
68 + Div<Output = T>
69 + std::iter::Sum
70 + scirs2_core::simd_ops::SimdUnifiedOps
71 + scirs2_core::SparseElement
72 + PartialOrd
73 + Send
74 + Sync
75 + 'static,
76{
77 let opts = options.unwrap_or_default();
78 let k = k.unwrap_or(6);
79 let which = which.unwrap_or("LA");
80
81 let (n_a, m_a) = a_matrix.shape();
82 let (n_b, m_b) = b_matrix.shape();
83
84 if n_a != m_a || n_b != m_b {
85 return Err(SparseError::ValueError(
86 "Both matrices must be square for generalized eigenvalue problem".to_string(),
87 ));
88 }
89
90 if n_a != n_b {
91 return Err(SparseError::DimensionMismatch {
92 expected: n_a,
93 found: n_b,
94 });
95 }
96
97 generalized_standard_transform(a_matrix, b_matrix, k, which, &opts)
100}
101
102#[allow(dead_code)]
142#[allow(clippy::too_many_arguments)]
143pub fn eigsh_generalized_enhanced<T>(
144 a_matrix: &SymCsrMatrix<T>,
145 b_matrix: &SymCsrMatrix<T>,
146 k: Option<usize>,
147 which: Option<&str>,
148 mode: Option<&str>,
149 sigma: Option<T>,
150 options: Option<LanczosOptions>,
151) -> SparseResult<EigenResult<T>>
152where
153 T: Float
154 + Debug
155 + Copy
156 + Add<Output = T>
157 + Sub<Output = T>
158 + Mul<Output = T>
159 + Div<Output = T>
160 + std::iter::Sum
161 + scirs2_core::simd_ops::SimdUnifiedOps
162 + scirs2_core::SparseElement
163 + PartialOrd
164 + Send
165 + Sync
166 + 'static,
167{
168 let mode = mode.unwrap_or("standard");
169 let _sigma = sigma.unwrap_or(T::sparse_zero());
170
171 match mode {
172 "standard" => eigsh_generalized(a_matrix, b_matrix, k, which, options),
173 "buckling" => {
174 eigsh_generalized(a_matrix, b_matrix, k, which, options)
177 }
178 "cayley" => {
179 eigsh_generalized(a_matrix, b_matrix, k, which, options)
182 }
183 _ => Err(SparseError::ValueError(format!(
184 "Unknown mode '{}'. Supported modes: standard, buckling, cayley",
185 mode
186 ))),
187 }
188}
189
190fn generalized_standard_transform<T>(
195 a_matrix: &SymCsrMatrix<T>,
196 b_matrix: &SymCsrMatrix<T>,
197 k: usize,
198 which: &str,
199 options: &LanczosOptions,
200) -> SparseResult<EigenResult<T>>
201where
202 T: Float
203 + Debug
204 + Copy
205 + Add<Output = T>
206 + Sub<Output = T>
207 + Mul<Output = T>
208 + Div<Output = T>
209 + std::iter::Sum
210 + scirs2_core::simd_ops::SimdUnifiedOps
211 + scirs2_core::SparseElement
212 + PartialOrd
213 + Send
214 + Sync
215 + 'static,
216{
217 let n = a_matrix.shape().0;
218
219 if !is_positive_definite_diagonal(b_matrix)? {
225 return Err(SparseError::ValueError(
226 "B matrix must be positive definite for standard transformation".to_string(),
227 ));
228 }
229
230 let transformed_matrix = compute_generalized_matrix(a_matrix, b_matrix)?;
233
234 let mut transform_opts = options.clone();
236 transform_opts.numeigenvalues = k;
237
238 let result = symmetric::eigsh(
239 &transformed_matrix,
240 Some(k),
241 Some(which),
242 Some(transform_opts),
243 )?;
244
245 Ok(result)
247}
248
249fn is_positive_definite_diagonal<T>(matrix: &SymCsrMatrix<T>) -> SparseResult<bool>
251where
252 T: Float + Debug + Copy + scirs2_core::SparseElement + PartialOrd,
253{
254 let n = matrix.shape().0;
255
256 for i in 0..n {
258 let mut diagonal_found = false;
259 let mut diagonal_value = T::sparse_zero();
260
261 for j in matrix.indptr[i]..matrix.indptr[i + 1] {
263 if matrix.indices[j] == i {
264 diagonal_value = matrix.data[j];
265 diagonal_found = true;
266 break;
267 }
268 }
269
270 if !diagonal_found || diagonal_value <= T::sparse_zero() {
271 return Ok(false);
272 }
273 }
274
275 Ok(true)
276}
277
278fn compute_generalized_matrix<T>(
281 a_matrix: &SymCsrMatrix<T>,
282 b_matrix: &SymCsrMatrix<T>,
283) -> SparseResult<SymCsrMatrix<T>>
284where
285 T: Float
286 + Debug
287 + Copy
288 + Add<Output = T>
289 + Sub<Output = T>
290 + Mul<Output = T>
291 + Div<Output = T>
292 + scirs2_core::SparseElement
293 + PartialOrd,
294{
295 let n = a_matrix.shape().0;
296
297 let epsilon = T::from(1e-12).unwrap_or(T::epsilon());
303
304 let mut new_data = a_matrix.data.clone();
305 let new_indices = a_matrix.indices.clone();
306 let new_indptr = a_matrix.indptr.clone();
307
308 for i in 0..n {
310 for j in new_indptr[i]..new_indptr[i + 1] {
311 if new_indices[j] == i {
312 new_data[j] = new_data[j] + epsilon;
313 break;
314 }
315 }
316 }
317
318 SymCsrMatrix::new(new_data, new_indptr, new_indices, (n, n))
319}
320
321#[allow(dead_code)]
323pub fn eigsh_generalized_shift_invert<T>(
324 a_matrix: &SymCsrMatrix<T>,
325 b_matrix: &SymCsrMatrix<T>,
326 sigma: T,
327 k: Option<usize>,
328 which: Option<&str>,
329 options: Option<LanczosOptions>,
330) -> SparseResult<EigenResult<T>>
331where
332 T: Float
333 + Debug
334 + Copy
335 + Add<Output = T>
336 + Sub<Output = T>
337 + Mul<Output = T>
338 + Div<Output = T>
339 + std::iter::Sum
340 + scirs2_core::simd_ops::SimdUnifiedOps
341 + scirs2_core::SparseElement
342 + PartialOrd
343 + Send
344 + Sync
345 + 'static,
346{
347 let k = k.unwrap_or(6);
348 let which = which.unwrap_or("LM");
349
350 generalized_standard_transform(a_matrix, b_matrix, k, which, &options.unwrap_or_default())
355}
356
357#[derive(Debug, Clone)]
359pub struct GeneralizedEigenSolverConfig {
360 pub k: usize,
362 pub which: String,
364 pub mode: String,
366 pub sigma: Option<f64>,
368 pub enhanced: bool,
370 pub lanczos_options: LanczosOptions,
372}
373
374impl Default for GeneralizedEigenSolverConfig {
375 fn default() -> Self {
376 Self {
377 k: 6,
378 which: "LA".to_string(),
379 mode: "standard".to_string(),
380 sigma: None,
381 enhanced: false,
382 lanczos_options: LanczosOptions::default(),
383 }
384 }
385}
386
387impl GeneralizedEigenSolverConfig {
388 pub fn new() -> Self {
390 Self::default()
391 }
392
393 pub fn with_k(mut self, k: usize) -> Self {
395 self.k = k;
396 self
397 }
398
399 pub fn with_which(mut self, which: &str) -> Self {
401 self.which = which.to_string();
402 self
403 }
404
405 pub fn with_mode(mut self, mode: &str) -> Self {
407 self.mode = mode.to_string();
408 self
409 }
410
411 pub fn with_sigma(mut self, sigma: f64) -> Self {
413 self.sigma = Some(sigma);
414 self
415 }
416
417 pub fn with_enhanced(mut self, enhanced: bool) -> Self {
419 self.enhanced = enhanced;
420 self
421 }
422
423 pub fn with_lanczos_options(mut self, options: LanczosOptions) -> Self {
425 self.lanczos_options = options;
426 self
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_csr::SymCsrMatrix;

    /// Builds a 2x2 symmetric matrix from its lower triangle
    /// `[[d0, off], [off, d1]]`, stored row by row as `d0 | off, d1`.
    fn lower_2x2(d0: f64, off: f64, d1: f64) -> SymCsrMatrix<f64> {
        SymCsrMatrix::new(vec![d0, off, d1], vec![0, 1, 3], vec![0, 0, 1], (2, 2)).unwrap()
    }

    #[test]
    fn test_eigsh_generalized_basic() {
        let a_matrix = lower_2x2(2.0, 1.0, 3.0);
        let b_matrix = lower_2x2(1.0, 0.5, 2.0);

        // Smoke test: convergence on such a tiny problem is not guaranteed,
        // so either outcome is accepted. The call only has to return without
        // panicking.
        let _outcome = eigsh_generalized(&a_matrix, &b_matrix, Some(1), None, None);
    }

    #[test]
    fn test_eigsh_generalized_rejects_dimension_mismatch() {
        let a_matrix = lower_2x2(2.0, 1.0, 3.0);
        // 3x3 identity, stored as its diagonal only.
        let b_matrix =
            SymCsrMatrix::new(vec![1.0, 1.0, 1.0], vec![0, 1, 2, 3], vec![0, 1, 2], (3, 3))
                .unwrap();

        let result = eigsh_generalized(&a_matrix, &b_matrix, Some(1), None, None);

        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_is_positive_definite_diagonal() {
        let matrix = lower_2x2(2.0, 1.0, 3.0);

        let result = is_positive_definite_diagonal(&matrix).unwrap();
        assert!(result);
    }

    #[test]
    fn test_is_positive_definite_diagonal_rejects_negative_entry() {
        let matrix = lower_2x2(2.0, 1.0, -3.0);

        let result = is_positive_definite_diagonal(&matrix).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_generalized_config() {
        let config = GeneralizedEigenSolverConfig::new()
            .with_k(5)
            .with_which("SA")
            .with_mode("buckling")
            .with_sigma(1.5)
            .with_enhanced(true);

        assert_eq!(config.k, 5);
        assert_eq!(config.which, "SA");
        assert_eq!(config.mode, "buckling");
        assert_eq!(config.sigma, Some(1.5));
        assert!(config.enhanced);
    }

    #[test]
    fn test_eigsh_generalized_enhanced() {
        let a_matrix = lower_2x2(4.0, 1.0, 2.0);
        let b_matrix = lower_2x2(2.0, 0.5, 1.0);

        // Smoke test, see `test_eigsh_generalized_basic`.
        let _outcome = eigsh_generalized_enhanced(
            &a_matrix,
            &b_matrix,
            Some(1),
            Some("LA"),
            Some("standard"),
            None,
            None,
        );
    }

    #[test]
    fn test_eigsh_generalized_enhanced_rejects_unknown_mode() {
        let a_matrix = lower_2x2(4.0, 1.0, 2.0);
        let b_matrix = lower_2x2(2.0, 0.5, 1.0);

        let result = eigsh_generalized_enhanced(
            &a_matrix,
            &b_matrix,
            Some(1),
            Some("LA"),
            Some("no-such-mode"),
            None,
            None,
        );

        assert!(matches!(result, Err(SparseError::ValueError(_))));
    }

    #[test]
    fn test_compute_generalized_matrix() {
        let a_matrix = lower_2x2(3.0, 1.0, 4.0);
        let b_matrix = lower_2x2(1.0, 0.5, 2.0);

        let result = compute_generalized_matrix(&a_matrix, &b_matrix);
        assert!(result.is_ok());

        let transformed = result.unwrap();
        assert_eq!(transformed.shape(), (2, 2));

        // Diagonal entries receive a small positive shift; the off-diagonal
        // entry is carried over unchanged.
        assert!(transformed.data[0] > 3.0);
        assert_eq!(transformed.data[1], 1.0);
        assert!(transformed.data[2] > 4.0);
    }
}