1use crate::csr::CsrMatrix;
7use crate::error::{SparseError, SparseResult};
8use scirs2_core::numeric::{Float, NumAssign, One, SparseElement, Zero};
9use std::iter::Sum;
10
11#[allow(dead_code)]
30pub fn expm<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
31where
32 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
33{
34 let (rows, cols) = a.shape();
35 if rows != cols {
36 return Err(SparseError::ValueError(
37 "Matrix must be square for expm".to_string(),
38 ));
39 }
40
41 let a_norm = matrix_inf_norm(a)?;
43
44 let theta_13 = F::from(5.371920351148152).unwrap();
46
47 if a_norm <= theta_13 {
49 return pade_approximation(a, 13);
50 }
51
52 let mut s = 0;
55 let mut scaled_norm = a_norm;
56 let two = F::from(2.0).unwrap();
57
58 while scaled_norm > theta_13 {
59 s += 1;
60 scaled_norm /= two;
61 }
62
63 let scale_factor = two.powi(s);
65 let scaled_a = scale_matrix(a, F::sparse_one() / scale_factor)?;
66
67 let mut exp_scaled = pade_approximation(&scaled_a, 13)?;
69
70 for _ in 0..s {
72 exp_scaled = exp_scaled.matmul(&exp_scaled)?;
73 }
74
75 Ok(exp_scaled)
76}
77
78#[allow(dead_code)]
82fn pade_approximation<F>(a: &CsrMatrix<F>, p: usize) -> SparseResult<CsrMatrix<F>>
83where
84 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
85{
86 let n = a.shape().0;
87
88 let mut a_powers = vec![sparse_identity(n)?]; a_powers.push(a.clone()); for i in 2..=p {
94 let prev = &a_powers[i - 1];
95 let power = prev.matmul(a)?;
96 a_powers.push(power);
97 }
98
99 let pade_coeffs = match p {
101 6 => vec![
102 F::from(1.0).unwrap(),
103 F::from(1.0 / 2.0).unwrap(),
104 F::from(3.0 / 26.0).unwrap(),
105 F::from(1.0 / 312.0).unwrap(),
106 F::from(1.0 / 10608.0).unwrap(),
107 F::from(1.0 / 358800.0).unwrap(),
108 F::from(1.0 / 17297280.0).unwrap(),
109 ],
110 13 => {
111 let two_p = 26i64;
114 let p = 13i64;
115 let mut coeffs = Vec::with_capacity(14);
116
117 for k in 0..=p {
118 let mut num = 1.0;
119 let mut den = 1.0;
120
121 for i in (two_p - k + 1)..=two_p {
123 den *= i as f64;
124 }
125
126 for i in (p - k + 1)..=p {
128 num *= i as f64;
129 }
130
131 let mut k_fact = 1.0;
133 for i in 1..=k {
134 k_fact *= i as f64;
135 }
136
137 coeffs.push(F::from(num / (den * k_fact)).unwrap());
138 }
139
140 coeffs
141 }
142 _ => {
143 let mut coeffs = vec![F::sparse_zero(); p + 1];
145 let mut factorial: F = F::sparse_one();
146 for (i, coeff) in coeffs.iter_mut().enumerate().take(p + 1) {
147 if i > 0 {
148 factorial *= F::from(i).unwrap();
149 }
150 let numerator = factorial;
151 let mut denominator = F::sparse_one();
152 for j in 1..=i {
153 denominator *= F::from(p + 1 - j).unwrap();
154 }
155 for j in 1..=(p - i) {
156 denominator *= F::from(j).unwrap();
157 }
158 *coeff = numerator / denominator;
159 }
160 coeffs
161 }
162 };
163
164 let mut u = sparse_zero(n)?;
166 let mut v = sparse_zero(n)?;
167
168 for (i, coeff) in pade_coeffs.iter().enumerate() {
170 let scaled_matrix = scale_matrix(&a_powers[i], *coeff)?;
171 if i % 2 == 0 {
172 v = sparse_add(&v, &scaled_matrix)?;
173 } else {
174 u = sparse_add(&u, &scaled_matrix)?;
175 }
176 }
177
178 let neg_u = scale_matrix(&u, F::from(-1.0).unwrap())?;
180 let v_minus_u = sparse_add(&v, &neg_u)?;
181 let v_plus_u = sparse_add(&v, &u)?;
182
183 sparse_solve(&v_minus_u, &v_plus_u)
185}
186
187#[allow(dead_code)]
189fn matrix_inf_norm<F>(a: &CsrMatrix<F>) -> SparseResult<F>
190where
191 F: Float + NumAssign + Sum + SparseElement + std::fmt::Debug,
192{
193 let mut max_row_sum = F::sparse_zero();
194
195 for row in 0..a.rows() {
197 let start = a.indptr[row];
198 let end = a.indptr[row + 1];
199 let row_sum: F = a.data[start..end].iter().map(|x| x.abs()).sum();
200
201 if row_sum > max_row_sum {
202 max_row_sum = row_sum;
203 }
204 }
205
206 Ok(max_row_sum)
207}
208
209#[allow(dead_code)]
211fn scale_matrix<F>(a: &CsrMatrix<F>, scale: F) -> SparseResult<CsrMatrix<F>>
212where
213 F: Float + NumAssign + SparseElement,
214{
215 let mut data = a.data.clone();
216 for val in data.iter_mut() {
217 *val *= scale;
218 }
219 CsrMatrix::from_raw_csr(data, a.indptr.clone(), a.indices.clone(), a.shape())
220}
221
222#[allow(dead_code)]
224fn sparse_identity<F>(n: usize) -> SparseResult<CsrMatrix<F>>
225where
226 F: Float + Zero + One + SparseElement,
227{
228 let mut rows = Vec::with_capacity(n);
229 let mut cols = Vec::with_capacity(n);
230 let mut values = Vec::with_capacity(n);
231
232 for i in 0..n {
233 rows.push(i);
234 cols.push(i);
235 values.push(F::sparse_one());
236 }
237
238 CsrMatrix::new(values, rows, cols, (n, n))
239}
240
241#[allow(dead_code)]
243fn sparse_zero<F>(n: usize) -> SparseResult<CsrMatrix<F>>
244where
245 F: Float + Zero + SparseElement,
246{
247 Ok(CsrMatrix::empty((n, n)))
248}
249
250#[allow(dead_code)]
252fn sparse_add<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
253where
254 F: Float + NumAssign + SparseElement,
255{
256 if a.shape() != b.shape() {
257 return Err(SparseError::ShapeMismatch {
258 expected: a.shape(),
259 found: b.shape(),
260 });
261 }
262
263 let mut rows = Vec::new();
264 let mut cols = Vec::new();
265 let mut values = Vec::new();
266
267 for i in 0..a.rows() {
268 for j in 0..a.cols() {
269 let val = a.get(i, j) + b.get(i, j);
270 if val.abs() > F::epsilon() {
271 rows.push(i);
272 cols.push(j);
273 values.push(val);
274 }
275 }
276 }
277
278 CsrMatrix::new(values, rows, cols, a.shape())
279}
280
281#[allow(dead_code)]
285fn sparse_solve<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
286where
287 F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
288{
289 use crate::linalg::interface::MatrixLinearOperator;
290 use crate::linalg::iterative::bicgstab;
291 use crate::linalg::iterative::BiCGSTABOptions;
292
293 let n = a.rows();
294 let mut result_rows = Vec::new();
295 let mut result_cols = Vec::new();
296 let mut result_values = Vec::new();
297
298 for col in 0..b.cols() {
300 let b_col = (0..n).map(|row| b.get(row, col)).collect::<Vec<_>>();
302
303 let op = MatrixLinearOperator::new(a.clone());
305
306 let options = BiCGSTABOptions {
308 rtol: F::from(1e-10).unwrap(),
309 atol: F::from(1e-12).unwrap(),
310 max_iter: 1000,
311 x0: None,
312 left_preconditioner: None,
313 right_preconditioner: None,
314 };
315
316 let result = bicgstab(&op, &b_col, options)?;
318
319 if !result.converged {
321 return Err(SparseError::IterativeSolverFailure(format!(
322 "BiCGSTAB failed to converge in {} iterations",
323 result.iterations
324 )));
325 }
326
327 for (row, &val) in result.x.iter().enumerate() {
329 if val.abs() > F::epsilon() {
330 result_rows.push(row);
331 result_cols.push(col);
332 result_values.push(val);
333 }
334 }
335 }
336
337 CsrMatrix::new(result_values, result_rows, result_cols, (n, b.cols()))
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// exp of the zero matrix must be the identity.
    #[test]
    fn test_expm_identity() {
        let n = 3;
        let exp_zero = expm(&sparse_zero::<f64>(n).unwrap()).unwrap();

        for (i, j) in (0..n).flat_map(|i| (0..n).map(move |j| (i, j))) {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_relative_eq!(exp_zero.get(i, j), expected, epsilon = 1e-10);
        }
    }

    /// exp of a diagonal matrix exponentiates the diagonal element-wise and
    /// keeps the off-diagonal entries at zero.
    #[test]
    fn test_expm_diagonal() {
        let n = 3;
        let diag_values: [f64; 3] = [0.5, 1.0, 2.0];
        let positions: Vec<usize> = (0..diag_values.len()).collect();
        let diag_matrix =
            CsrMatrix::new(diag_values.to_vec(), positions.clone(), positions, (n, n)).unwrap();
        let exp_diag = expm(&diag_matrix).unwrap();

        for (i, val) in diag_values.iter().enumerate() {
            assert_relative_eq!(exp_diag.get(i, i), val.exp(), epsilon = 1e-10);
        }

        for i in 0..n {
            for j in (0..n).filter(|&j| j != i) {
                assert_relative_eq!(exp_diag.get(i, j), 0.0, epsilon = 1e-10);
            }
        }
    }

    /// A = [[0, 1], [0, 0]] is nilpotent, so exp(A) = I + A = [[1, 1], [0, 1]].
    #[test]
    fn test_expm_small_matrix() {
        // Position (1, 0) is stored as an explicit zero.
        let values: Vec<f64> = vec![1.0, 0.0];
        let a = CsrMatrix::new(values, vec![0, 1], vec![1, 0], (2, 2)).unwrap();
        let exp_a = expm(&a).unwrap();

        let expected: [[f64; 2]; 2] = [[1.0, 1.0], [0.0, 1.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(exp_a.get(i, j), want, epsilon = 1e-10);
            }
        }
    }
}