scirs2_sparse/linalg/eigen/
symmetric.rs1use super::lanczos::{lanczos, EigenResult, LanczosOptions};
7use crate::error::{SparseError, SparseResult};
8use crate::sym_csr::SymCsrMatrix;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::fmt::Debug;
13use std::ops::{Add, Div, Mul, Sub};
14
15#[allow(dead_code)]
47pub fn eigsh<T>(
48 matrix: &SymCsrMatrix<T>,
49 k: Option<usize>,
50 which: Option<&str>,
51 options: Option<LanczosOptions>,
52) -> SparseResult<EigenResult<T>>
53where
54 T: Float
55 + Debug
56 + Copy
57 + Add<Output = T>
58 + Sub<Output = T>
59 + Mul<Output = T>
60 + Div<Output = T>
61 + std::iter::Sum
62 + scirs2_core::simd_ops::SimdUnifiedOps
63 + SparseElement
64 + PartialOrd
65 + Send
66 + Sync
67 + 'static,
68{
69 let opts = options.unwrap_or_default();
70 let k = k.unwrap_or(opts.numeigenvalues);
71 let which = which.unwrap_or("LA");
72
73 let (n, m) = matrix.shape();
74 if n != m {
75 return Err(SparseError::ValueError(
76 "Matrix must be square for eigenvalue computation".to_string(),
77 ));
78 }
79
80 enhanced_lanczos(matrix, k, which, &opts)
82}
83
84#[allow(dead_code)]
119pub fn eigsh_shift_invert<T>(
120 matrix: &SymCsrMatrix<T>,
121 sigma: T,
122 k: Option<usize>,
123 which: Option<&str>,
124 options: Option<LanczosOptions>,
125) -> SparseResult<EigenResult<T>>
126where
127 T: Float
128 + Debug
129 + Copy
130 + Add<Output = T>
131 + Sub<Output = T>
132 + Mul<Output = T>
133 + Div<Output = T>
134 + std::iter::Sum
135 + scirs2_core::simd_ops::SimdUnifiedOps
136 + SparseElement
137 + PartialOrd
138 + Send
139 + Sync
140 + 'static,
141{
142 let opts = options.unwrap_or_default();
143 let k = k.unwrap_or(6);
144 let which = which.unwrap_or("LM");
145
146 let (n, m) = matrix.shape();
147 if n != m {
148 return Err(SparseError::ValueError(
149 "Matrix must be square for eigenvalue computation".to_string(),
150 ));
151 }
152
153 let mut shifted_matrix = matrix.clone();
159
160 for i in 0..n {
164 for j in shifted_matrix.indptr[i]..shifted_matrix.indptr[i + 1] {
166 if shifted_matrix.indices[j] == i {
167 shifted_matrix.data[j] = shifted_matrix.data[j] - sigma;
168 break;
169 }
170 }
171 }
172
173 let mut shift_opts = opts.clone();
176 shift_opts.numeigenvalues = k;
177
178 let result = lanczos(&shifted_matrix, &shift_opts, None)?;
179
180 let mut transformed_eigenvalues = Array1::zeros(result.eigenvalues.len());
182 for (i, &mu) in result.eigenvalues.iter().enumerate() {
183 if !SparseElement::is_zero(&mu) {
184 transformed_eigenvalues[i] = sigma + T::sparse_one() / mu;
185 } else {
186 transformed_eigenvalues[i] = sigma;
187 }
188 }
189
190 Ok(EigenResult {
191 eigenvalues: transformed_eigenvalues,
192 eigenvectors: result.eigenvectors,
193 iterations: result.iterations,
194 residuals: result.residuals,
195 converged: result.converged,
196 })
197}
198
199#[allow(dead_code)]
218#[allow(clippy::too_many_arguments)]
219pub fn eigsh_shift_invert_enhanced<T>(
220 matrix: &SymCsrMatrix<T>,
221 sigma: T,
222 k: Option<usize>,
223 which: Option<&str>,
224 mode: Option<&str>,
225 return_eigenvectors: Option<bool>,
226 options: Option<LanczosOptions>,
227) -> SparseResult<EigenResult<T>>
228where
229 T: Float
230 + Debug
231 + Copy
232 + Add<Output = T>
233 + Sub<Output = T>
234 + Mul<Output = T>
235 + Div<Output = T>
236 + std::iter::Sum
237 + scirs2_core::simd_ops::SimdUnifiedOps
238 + SparseElement
239 + PartialOrd
240 + Send
241 + Sync
242 + 'static,
243{
244 let _mode = mode.unwrap_or("normal");
245 let _return_eigenvectors = return_eigenvectors.unwrap_or(true);
246
247 eigsh_shift_invert(matrix, sigma, k, which, options)
249}
250
251fn enhanced_lanczos<T>(
256 matrix: &SymCsrMatrix<T>,
257 k: usize,
258 which: &str,
259 options: &LanczosOptions,
260) -> SparseResult<EigenResult<T>>
261where
262 T: Float
263 + Debug
264 + Copy
265 + Add<Output = T>
266 + Sub<Output = T>
267 + Mul<Output = T>
268 + Div<Output = T>
269 + std::iter::Sum
270 + scirs2_core::simd_ops::SimdUnifiedOps
271 + SparseElement
272 + PartialOrd
273 + Send
274 + Sync
275 + 'static,
276{
277 let n = matrix.shape().0;
278
279 let mut enhanced_opts = options.clone();
281 enhanced_opts.numeigenvalues = k;
282
283 enhanced_opts.max_subspace_size = (k * 2 + 10).min(n);
285
286 enhanced_opts.tol = enhanced_opts.tol.min(1e-10);
288
289 let result = lanczos(matrix, &enhanced_opts, None)?;
291
292 process_eigenvalue_selection(result, which, k)
294}
295
296fn process_eigenvalue_selection<T>(
298 mut result: EigenResult<T>,
299 which: &str,
300 k: usize,
301) -> SparseResult<EigenResult<T>>
302where
303 T: Float + Debug + Copy,
304{
305 let n_computed = result.eigenvalues.len();
306 let n_requested = k.min(n_computed);
307
308 match which {
309 "LA" => {
310 result.eigenvalues = result
312 .eigenvalues
313 .slice(scirs2_core::ndarray::s![..n_requested])
314 .to_owned();
315 if let Some(ref mut evecs) = result.eigenvectors {
316 *evecs = evecs
317 .slice(scirs2_core::ndarray::s![.., ..n_requested])
318 .to_owned();
319 }
320 result.residuals = result
321 .residuals
322 .slice(scirs2_core::ndarray::s![..n_requested])
323 .to_owned();
324 }
325 "SA" => {
326 let mut eigenvals = result.eigenvalues.to_vec();
328 eigenvals.reverse();
329 result.eigenvalues = Array1::from_vec(eigenvals[..n_requested].to_vec());
330
331 if let Some(ref mut evecs) = result.eigenvectors {
332 let ncols = evecs.ncols();
333 let mut evecs_vec = Vec::new();
334 for j in (0..ncols).rev().take(n_requested) {
335 for i in 0..evecs.nrows() {
336 evecs_vec.push(evecs[[i, j]]);
337 }
338 }
339 *evecs = scirs2_core::ndarray::Array2::from_shape_vec(
340 (evecs.nrows(), n_requested),
341 evecs_vec,
342 )
343 .map_err(|_| {
344 SparseError::ValueError("Failed to reshape eigenvectors".to_string())
345 })?;
346 }
347
348 let mut residuals = result.residuals.to_vec();
349 residuals.reverse();
350 result.residuals = Array1::from_vec(residuals[..n_requested].to_vec());
351 }
352 "LM" => {
353 let mut indices: Vec<usize> = (0..n_computed).collect();
355 indices.sort_by(|&i, &j| {
356 result.eigenvalues[j]
357 .abs()
358 .partial_cmp(&result.eigenvalues[i].abs())
359 .unwrap_or(std::cmp::Ordering::Equal)
360 });
361
362 let mut new_eigenvals = Vec::new();
363 let mut new_residuals = Vec::new();
364
365 for &idx in indices.iter().take(n_requested) {
366 new_eigenvals.push(result.eigenvalues[idx]);
367 new_residuals.push(result.residuals[idx]);
368 }
369
370 result.eigenvalues = Array1::from_vec(new_eigenvals);
371 result.residuals = Array1::from_vec(new_residuals);
372
373 if let Some(ref mut evecs) = result.eigenvectors {
374 let mut new_evecs = Vec::new();
375 for &idx in indices.iter().take(n_requested) {
376 for i in 0..evecs.nrows() {
377 new_evecs.push(evecs[[i, idx]]);
378 }
379 }
380 *evecs = scirs2_core::ndarray::Array2::from_shape_vec(
381 (evecs.nrows(), n_requested),
382 new_evecs,
383 )
384 .map_err(|_| {
385 SparseError::ValueError("Failed to reshape eigenvectors".to_string())
386 })?;
387 }
388 }
389 "SM" => {
390 let mut indices: Vec<usize> = (0..n_computed).collect();
392 indices.sort_by(|&i, &j| {
393 result.eigenvalues[i]
394 .abs()
395 .partial_cmp(&result.eigenvalues[j].abs())
396 .unwrap_or(std::cmp::Ordering::Equal)
397 });
398
399 let mut new_eigenvals = Vec::new();
400 let mut new_residuals = Vec::new();
401
402 for &idx in indices.iter().take(n_requested) {
403 new_eigenvals.push(result.eigenvalues[idx]);
404 new_residuals.push(result.residuals[idx]);
405 }
406
407 result.eigenvalues = Array1::from_vec(new_eigenvals);
408 result.residuals = Array1::from_vec(new_residuals);
409
410 if let Some(ref mut evecs) = result.eigenvectors {
411 let mut new_evecs = Vec::new();
412 for &idx in indices.iter().take(n_requested) {
413 for i in 0..evecs.nrows() {
414 new_evecs.push(evecs[[i, idx]]);
415 }
416 }
417 *evecs = scirs2_core::ndarray::Array2::from_shape_vec(
418 (evecs.nrows(), n_requested),
419 new_evecs,
420 )
421 .map_err(|_| {
422 SparseError::ValueError("Failed to reshape eigenvectors".to_string())
423 })?;
424 }
425 }
426 _ => {
427 return Err(SparseError::ValueError(format!(
428 "Unknown eigenvalue selection criterion: {}. Use 'LA', 'SA', 'LM', or 'SM'",
429 which
430 )));
431 }
432 }
433
434 Ok(result)
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_csr::SymCsrMatrix;

    /// Smoke test: `eigsh` on a 3x3 matrix returns at most the requested
    /// number of finite eigenvalues.
    #[test]
    fn test_eigsh_basic() {
        let values = vec![4.0, 2.0, 3.0, 5.0, 1.0];
        let row_ptr = vec![0, 1, 3, 5];
        let col_idx = vec![0, 0, 1, 1, 2];
        let matrix = SymCsrMatrix::new(values, row_ptr, col_idx, (3, 3)).unwrap();

        let outcome = eigsh(&matrix, Some(2), Some("LA"), None).unwrap();

        assert!(!outcome.eigenvalues.is_empty());
        assert!(outcome.eigenvalues.len() <= 2);
        assert!(outcome.eigenvalues[0].is_finite());
    }

    /// Both the "LA" and the "SA" criteria yield a finite eigenvalue on a
    /// 2x2 matrix.
    #[test]
    fn test_eigsh_different_which() {
        let values = vec![2.0, 1.0, 2.0];
        let row_ptr = vec![0, 1, 3];
        let col_idx = vec![0, 0, 1];
        let matrix = SymCsrMatrix::new(values, row_ptr, col_idx, (2, 2)).unwrap();

        let largest = eigsh(&matrix, Some(1), Some("LA"), None).unwrap();
        assert!(!largest.eigenvalues.is_empty());
        assert!(largest.eigenvalues[0].is_finite());

        let smallest = eigsh(&matrix, Some(1), Some("SA"), None).unwrap();
        assert!(!smallest.eigenvalues.is_empty());
        assert!(smallest.eigenvalues[0].is_finite());
    }

    /// Smoke test: the shifted solve with `sigma = 2.0` yields a finite
    /// eigenvalue.
    #[test]
    fn test_eigsh_shift_invert() {
        let values = vec![4.0, 1.0, 3.0];
        let row_ptr = vec![0, 1, 3];
        let col_idx = vec![0, 0, 1];
        let matrix = SymCsrMatrix::new(values, row_ptr, col_idx, (2, 2)).unwrap();

        let outcome = eigsh_shift_invert(&matrix, 2.0, Some(1), None, None).unwrap();

        assert!(!outcome.eigenvalues.is_empty());
        assert!(outcome.eigenvalues[0].is_finite());
    }

    /// "LA" keeps the leading values and "SA" the trailing values (reversed)
    /// of an already descending eigenvalue list.
    #[test]
    fn test_process_eigenvalue_selection() {
        let eigenvalues = Array1::from_vec(vec![5.0, 3.0, 1.0]);
        let residuals = Array1::from_vec(vec![1e-8, 1e-9, 1e-7]);
        let base = EigenResult {
            eigenvalues,
            eigenvectors: None,
            iterations: 10,
            residuals,
            converged: true,
        };

        let picked_la = process_eigenvalue_selection(base.clone(), "LA", 2).unwrap();
        assert_eq!(picked_la.eigenvalues.len(), 2);
        assert_eq!(picked_la.eigenvalues[0], 5.0);
        assert_eq!(picked_la.eigenvalues[1], 3.0);

        let picked_sa = process_eigenvalue_selection(base.clone(), "SA", 2).unwrap();
        assert_eq!(picked_sa.eigenvalues.len(), 2);
        assert_eq!(picked_sa.eigenvalues[0], 1.0);
        assert_eq!(picked_sa.eigenvalues[1], 3.0);
    }
}