scirs2_sparse/linalg/
expm.rs

1//! Matrix exponential computation for sparse matrices
2//!
3//! This module implements the matrix exponential using the scaling and squaring
4//! method with Padé approximation.
5
6use crate::csr::CsrMatrix;
7use crate::error::{SparseError, SparseResult};
8use scirs2_core::numeric::{Float, NumAssign, One, SparseElement, Zero};
9use std::iter::Sum;
10
11/// Compute the matrix exponential using scaling and squaring with Padé approximation
12///
13/// This function computes exp(A) for a sparse matrix A using the scaling and
14/// squaring method combined with Padé approximation.
15///
16/// # Arguments
17///
18/// * `a` - The sparse matrix A (must be square)
19///
20/// # Returns
21///
22/// The matrix exponential exp(A) as a sparse matrix
23///
24/// # Implementation Details
25///
26/// Uses 13th order Padé approximation for high accuracy (machine precision).
27/// The algorithm automatically selects the appropriate scaling factor based
28/// on the matrix norm to ensure numerical stability.
29#[allow(dead_code)]
30pub fn expm<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
31where
32    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
33{
34    let (rows, cols) = a.shape();
35    if rows != cols {
36        return Err(SparseError::ValueError(
37            "Matrix must be square for expm".to_string(),
38        ));
39    }
40
41    // Compute the matrix infinity norm
42    let a_norm = matrix_inf_norm(a)?;
43
44    // Constants for order 13 Padé approximation
45    let theta_13 = F::from(5.371920351148152).unwrap();
46
47    // If the norm is small enough, use direct Padé approximation
48    if a_norm <= theta_13 {
49        return pade_approximation(a, 13);
50    }
51
52    // Otherwise, use scaling and squaring
53    // Find s such that ||A/2^s|| <= theta_13
54    let mut s = 0;
55    let mut scaled_norm = a_norm;
56    let two = F::from(2.0).unwrap();
57
58    while scaled_norm > theta_13 {
59        s += 1;
60        scaled_norm /= two;
61    }
62
63    // Compute A/2^s
64    let scale_factor = two.powi(s);
65    let scaled_a = scale_matrix(a, F::sparse_one() / scale_factor)?;
66
67    // Compute exp(A/2^s) using Padé approximation
68    let mut exp_scaled = pade_approximation(&scaled_a, 13)?;
69
70    // Square the result s times to get exp(A)
71    for _ in 0..s {
72        exp_scaled = exp_scaled.matmul(&exp_scaled)?;
73    }
74
75    Ok(exp_scaled)
76}
77
78/// Compute the Padé approximation of exp(A)
79///
80/// Uses the diagonal Padé approximant of order (p,p)
81#[allow(dead_code)]
82fn pade_approximation<F>(a: &CsrMatrix<F>, p: usize) -> SparseResult<CsrMatrix<F>>
83where
84    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
85{
86    let n = a.shape().0;
87
88    // Compute powers of A
89    let mut a_powers = vec![sparse_identity(n)?]; // A^0 = I
90    a_powers.push(a.clone()); // A^1 = A
91
92    // Compute A^2, A^3, ..., A^p
93    for i in 2..=p {
94        let prev = &a_powers[i - 1];
95        let power = prev.matmul(a)?;
96        a_powers.push(power);
97    }
98
99    // Compute Padé coefficients
100    let pade_coeffs = match p {
101        6 => vec![
102            F::from(1.0).unwrap(),
103            F::from(1.0 / 2.0).unwrap(),
104            F::from(3.0 / 26.0).unwrap(),
105            F::from(1.0 / 312.0).unwrap(),
106            F::from(1.0 / 10608.0).unwrap(),
107            F::from(1.0 / 358800.0).unwrap(),
108            F::from(1.0 / 17297280.0).unwrap(),
109        ],
110        13 => {
111            // Compute coefficients for Padé (13,13) approximant
112            // c_k = (2p-k)! p! / ((2p)! k! (p-k)!) for k = 0, 1, ..., p
113            let two_p = 26i64;
114            let p = 13i64;
115            let mut coeffs = Vec::with_capacity(14);
116
117            for k in 0..=p {
118                let mut num = 1.0;
119                let mut den = 1.0;
120
121                // (2p-k)! / (2p)!
122                for i in (two_p - k + 1)..=two_p {
123                    den *= i as f64;
124                }
125
126                // p! / (p-k)!
127                for i in (p - k + 1)..=p {
128                    num *= i as f64;
129                }
130
131                // 1 / k!
132                let mut k_fact = 1.0;
133                for i in 1..=k {
134                    k_fact *= i as f64;
135                }
136
137                coeffs.push(F::from(num / (den * k_fact)).unwrap());
138            }
139
140            coeffs
141        }
142        _ => {
143            // General formula for Padé coefficients
144            let mut coeffs = vec![F::sparse_zero(); p + 1];
145            let mut factorial: F = F::sparse_one();
146            for (i, coeff) in coeffs.iter_mut().enumerate().take(p + 1) {
147                if i > 0 {
148                    factorial *= F::from(i).unwrap();
149                }
150                let numerator = factorial;
151                let mut denominator = F::sparse_one();
152                for j in 1..=i {
153                    denominator *= F::from(p + 1 - j).unwrap();
154                }
155                for j in 1..=(p - i) {
156                    denominator *= F::from(j).unwrap();
157                }
158                *coeff = numerator / denominator;
159            }
160            coeffs
161        }
162    };
163
164    // Compute U and V for the Padé approximant
165    let mut u = sparse_zero(n)?;
166    let mut v = sparse_zero(n)?;
167
168    // U = sum of odd powers, V = sum of even powers
169    for (i, coeff) in pade_coeffs.iter().enumerate() {
170        let scaled_matrix = scale_matrix(&a_powers[i], *coeff)?;
171        if i % 2 == 0 {
172            v = sparse_add(&v, &scaled_matrix)?;
173        } else {
174            u = sparse_add(&u, &scaled_matrix)?;
175        }
176    }
177
178    // Compute (V - U)^(-1) * (V + U)
179    let neg_u = scale_matrix(&u, F::from(-1.0).unwrap())?;
180    let v_minus_u = sparse_add(&v, &neg_u)?;
181    let v_plus_u = sparse_add(&v, &u)?;
182
183    // Solve (V - U) * X = (V + U) for X
184    sparse_solve(&v_minus_u, &v_plus_u)
185}
186
187/// Compute the infinity norm of a sparse matrix
188#[allow(dead_code)]
189fn matrix_inf_norm<F>(a: &CsrMatrix<F>) -> SparseResult<F>
190where
191    F: Float + NumAssign + Sum + SparseElement + std::fmt::Debug,
192{
193    let mut max_row_sum = F::sparse_zero();
194
195    // For CSR format, efficiently compute row sums
196    for row in 0..a.rows() {
197        let start = a.indptr[row];
198        let end = a.indptr[row + 1];
199        let row_sum: F = a.data[start..end].iter().map(|x| x.abs()).sum();
200
201        if row_sum > max_row_sum {
202            max_row_sum = row_sum;
203        }
204    }
205
206    Ok(max_row_sum)
207}
208
209/// Scale a sparse matrix by a scalar
210#[allow(dead_code)]
211fn scale_matrix<F>(a: &CsrMatrix<F>, scale: F) -> SparseResult<CsrMatrix<F>>
212where
213    F: Float + NumAssign + SparseElement,
214{
215    let mut data = a.data.clone();
216    for val in data.iter_mut() {
217        *val *= scale;
218    }
219    CsrMatrix::from_raw_csr(data, a.indptr.clone(), a.indices.clone(), a.shape())
220}
221
222/// Create a sparse identity matrix
223#[allow(dead_code)]
224fn sparse_identity<F>(n: usize) -> SparseResult<CsrMatrix<F>>
225where
226    F: Float + Zero + One + SparseElement,
227{
228    let mut rows = Vec::with_capacity(n);
229    let mut cols = Vec::with_capacity(n);
230    let mut values = Vec::with_capacity(n);
231
232    for i in 0..n {
233        rows.push(i);
234        cols.push(i);
235        values.push(F::sparse_one());
236    }
237
238    CsrMatrix::new(values, rows, cols, (n, n))
239}
240
241/// Create a sparse zero matrix
242#[allow(dead_code)]
243fn sparse_zero<F>(n: usize) -> SparseResult<CsrMatrix<F>>
244where
245    F: Float + Zero + SparseElement,
246{
247    Ok(CsrMatrix::empty((n, n)))
248}
249
250/// Add two sparse matrices
251#[allow(dead_code)]
252fn sparse_add<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
253where
254    F: Float + NumAssign + SparseElement,
255{
256    if a.shape() != b.shape() {
257        return Err(SparseError::ShapeMismatch {
258            expected: a.shape(),
259            found: b.shape(),
260        });
261    }
262
263    let mut rows = Vec::new();
264    let mut cols = Vec::new();
265    let mut values = Vec::new();
266
267    for i in 0..a.rows() {
268        for j in 0..a.cols() {
269            let val = a.get(i, j) + b.get(i, j);
270            if val.abs() > F::epsilon() {
271                rows.push(i);
272                cols.push(j);
273                values.push(val);
274            }
275        }
276    }
277
278    CsrMatrix::new(values, rows, cols, a.shape())
279}
280
281/// Solve a linear system A * X = B for sparse matrices
282///
283/// Note: This is a placeholder - in practice you'd use a more sophisticated solver
284#[allow(dead_code)]
285fn sparse_solve<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
286where
287    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
288{
289    use crate::linalg::interface::MatrixLinearOperator;
290    use crate::linalg::iterative::bicgstab;
291    use crate::linalg::iterative::BiCGSTABOptions;
292
293    let n = a.rows();
294    let mut result_rows = Vec::new();
295    let mut result_cols = Vec::new();
296    let mut result_values = Vec::new();
297
298    // Solve column by column
299    for col in 0..b.cols() {
300        // Extract the column from B
301        let b_col = (0..n).map(|row| b.get(row, col)).collect::<Vec<_>>();
302
303        // Create a linear operator for the matrix
304        let op = MatrixLinearOperator::new(a.clone());
305
306        // Create options for BiCGSTAB
307        let options = BiCGSTABOptions {
308            rtol: F::from(1e-10).unwrap(),
309            atol: F::from(1e-12).unwrap(),
310            max_iter: 1000,
311            x0: None,
312            left_preconditioner: None,
313            right_preconditioner: None,
314        };
315
316        // Use BiCGSTAB to solve A * x = b_col
317        let result = bicgstab(&op, &b_col, options)?;
318
319        // Check convergence
320        if !result.converged {
321            return Err(SparseError::IterativeSolverFailure(format!(
322                "BiCGSTAB failed to converge in {} iterations",
323                result.iterations
324            )));
325        }
326
327        // Add non-zero entries to result
328        for (row, &val) in result.x.iter().enumerate() {
329            if val.abs() > F::epsilon() {
330                result_rows.push(row);
331                result_cols.push(col);
332                result_values.push(val);
333            }
334        }
335    }
336
337    CsrMatrix::new(result_values, result_rows, result_cols, (n, b.cols()))
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The exponential of the zero matrix is the identity: exp(0) = I.
    #[test]
    fn test_expm_identity() {
        let n = 3;
        let exp_zero = expm(&sparse_zero::<f64>(n).unwrap()).unwrap();

        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_relative_eq!(exp_zero.get(i, j), expected, epsilon = 1e-10);
            }
        }
    }

    /// For a diagonal matrix D, exp(D) is diagonal with exp(d_ii) on the
    /// diagonal and zeros everywhere else.
    #[test]
    fn test_expm_diagonal() {
        let diag_values: [f64; 3] = [0.5, 1.0, 2.0];
        let n = diag_values.len();
        let positions: Vec<usize> = (0..n).collect();

        let diag_matrix =
            CsrMatrix::new(diag_values.to_vec(), positions.clone(), positions, (n, n)).unwrap();
        let exp_diag = expm(&diag_matrix).unwrap();

        // Diagonal entries
        for (i, d) in diag_values.iter().enumerate() {
            assert_relative_eq!(exp_diag.get(i, i), d.exp(), epsilon = 1e-10);
        }

        // Off-diagonal entries
        for i in 0..n {
            for j in (0..n).filter(|&j| j != i) {
                assert_relative_eq!(exp_diag.get(i, j), 0.0, epsilon = 1e-10);
            }
        }
    }

    /// Nilpotent 2x2 case with a known closed form:
    /// A = [[0, 1], [0, 0]]  =>  exp(A) = [[1, 1], [0, 1]].
    #[test]
    fn test_expm_small_matrix() {
        // The (1, 0) entry is stored explicitly with value 0.0.
        let a = CsrMatrix::new(vec![1.0_f64, 0.0], vec![0, 1], vec![1, 0], (2, 2)).unwrap();
        let exp_a = expm(&a).unwrap();

        let expected: [[f64; 2]; 2] = [[1.0, 1.0], [0.0, 1.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_relative_eq!(exp_a.get(i, j), want, epsilon = 1e-10);
            }
        }
    }
}