scirs2_optimize/sparse_numdiff/
hessian.rs

1//! Sparse Hessian computation using finite differences
2//!
3//! This module provides functions for computing sparse Hessian matrices
4//! using various finite difference methods.
5
6use ndarray::{Array1, ArrayView1};
7use scirs2_core::parallel_ops::*;
8use scirs2_sparse::{csr_array::CsrArray, sparray::SparseArray};
9
10use super::coloring::determine_column_groups;
11use super::finite_diff::{compute_step_sizes, SparseFiniteDiffOptions};
12use crate::error::OptimizeError;
13
14// Helper function to replace get_index and set_value_by_index which are not available in CsrArray
15fn update_sparse_value(matrix: &mut CsrArray<f64>, row: usize, col: usize, value: f64) {
16    // Only update if the position is non-zero in the sparsity pattern and set operation succeeds
17    if matrix.get(row, col) != 0.0 && matrix.set(row, col, value).is_err() {
18        // If this fails, just silently continue
19    }
20}
21
22// Helper function to check if a position exists in the sparsity pattern
23fn exists_in_sparsity(matrix: &CsrArray<f64>, row: usize, col: usize) -> bool {
24    matrix.get(row, col) != 0.0
25}
26
27/// Computes a sparse Hessian matrix using finite differences
28///
29/// # Arguments
30///
31/// * `func` - Function to differentiate, takes ArrayView1<f64> and returns f64
32/// * `grad` - Optional gradient function, takes ArrayView1<f64> and returns Array1<f64>
33/// * `x` - Point at which to compute the Hessian
34/// * `f0` - Function value at `x` (if None, computed internally)
35/// * `g0` - Gradient value at `x` (if None, computed internally)
36/// * `sparsity_pattern` - Sparse matrix indicating the known sparsity pattern (if None, dense Hessian)
37/// * `options` - Options for finite differences computation
38///
39/// # Returns
40///
41/// * `CsrArray<f64>` - Sparse Hessian matrix in CSR format
42///
43pub fn sparse_hessian<F, G>(
44    func: F,
45    grad: Option<G>,
46    x: &ArrayView1<f64>,
47    f0: Option<f64>,
48    g0: Option<&Array1<f64>>,
49    sparsity_pattern: Option<&CsrArray<f64>>,
50    options: Option<SparseFiniteDiffOptions>,
51) -> Result<CsrArray<f64>, OptimizeError>
52where
53    F: Fn(&ArrayView1<f64>) -> f64 + Sync,
54    G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
55{
56    let options = options.unwrap_or_default();
57    let n = x.len();
58
59    // If gradient function is provided, use it to compute Hessian via forward differences
60    // on the gradient
61    if let Some(gradient_fn) = grad {
62        return compute_hessian_from_gradient(gradient_fn, x, g0, sparsity_pattern, &options);
63    }
64
65    // If no sparsity pattern provided, create a dense one
66    let sparsity_owned: CsrArray<f64>;
67    let sparsity = match sparsity_pattern {
68        Some(p) => {
69            // Validate sparsity pattern
70            if p.shape().0 != n || p.shape().1 != n {
71                return Err(OptimizeError::ValueError(format!(
72                    "Sparsity pattern shape {:?} does not match input dimension {}",
73                    p.shape(),
74                    n
75                )));
76            }
77            p
78        }
79        None => {
80            // Create dense sparsity pattern
81            let mut data = Vec::with_capacity(n * n);
82            let mut rows = Vec::with_capacity(n * n);
83            let mut cols = Vec::with_capacity(n * n);
84
85            for i in 0..n {
86                for j in 0..n {
87                    data.push(1.0);
88                    rows.push(i);
89                    cols.push(j);
90                }
91            }
92
93            sparsity_owned = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
94            &sparsity_owned
95        }
96    };
97
98    // Ensure sparsity pattern is symmetric (Hessian is symmetric)
99    // In practice, we only need to compute the upper triangle and then
100    // fill in the lower triangle at the end
101    let symmetric_sparsity = make_symmetric_sparsity(sparsity)?;
102
103    // Choose implementation based on specified method
104    let result = match options.method.as_str() {
105        "2-point" => {
106            let f0_val = f0.unwrap_or_else(|| func(x));
107            compute_hessian_2point(func, x, f0_val, &symmetric_sparsity, &options)
108        }
109        "3-point" => compute_hessian_3point(func, x, &symmetric_sparsity, &options),
110        "cs" => compute_hessian_complex_step(func, x, &symmetric_sparsity, &options),
111        _ => Err(OptimizeError::ValueError(format!(
112            "Unknown method: {}. Valid options are '2-point', '3-point', and 'cs'",
113            options.method
114        ))),
115    }?;
116
117    // Fill in the lower triangle to ensure symmetry
118    fill_symmetric_hessian(&result)
119}
120
121/// Computes Hessian from a gradient function using forward differences
122fn compute_hessian_from_gradient<G>(
123    grad_fn: G,
124    x: &ArrayView1<f64>,
125    g0: Option<&Array1<f64>>,
126    sparsity_pattern: Option<&CsrArray<f64>>,
127    options: &SparseFiniteDiffOptions,
128) -> Result<CsrArray<f64>, OptimizeError>
129where
130    G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
131{
132    let _n = x.len();
133
134    // Compute g0 if not provided
135    let g0_owned: Array1<f64>;
136    let g0_ref = match g0 {
137        Some(g) => g,
138        None => {
139            g0_owned = grad_fn(x);
140            &g0_owned
141        }
142    };
143
144    // The gradient function can be treated as a vector-valued function,
145    // so we can use sparse_jacobian to compute the Hessian (which is the Jacobian of the gradient)
146    let jac_options = SparseFiniteDiffOptions {
147        method: options.method.clone(),
148        rel_step: options.rel_step,
149        abs_step: options.abs_step,
150        bounds: options.bounds.clone(),
151        parallel: options.parallel.clone(),
152        seed: options.seed,
153        max_group_size: options.max_group_size,
154    };
155
156    // Use sparse_jacobian to compute the Hessian
157    let hessian = super::jacobian::sparse_jacobian(
158        grad_fn,
159        x,
160        Some(g0_ref),
161        sparsity_pattern,
162        Some(jac_options),
163    )?;
164
165    // Ensure the Hessian is symmetric
166    fill_symmetric_hessian(&hessian)
167}
168
169/// Computes Hessian using 2-point finite differences
170fn compute_hessian_2point<F>(
171    func: F,
172    x: &ArrayView1<f64>,
173    f0: f64,
174    sparsity: &CsrArray<f64>,
175    options: &SparseFiniteDiffOptions,
176) -> Result<CsrArray<f64>, OptimizeError>
177where
178    F: Fn(&ArrayView1<f64>) -> f64 + Sync,
179{
180    let _n = x.len();
181
182    // Determine column groups using a graph coloring algorithm
183    let groups = determine_column_groups(sparsity, None, None)?;
184
185    // Compute step sizes
186    let h = compute_step_sizes(x, options);
187
188    // Create result matrix with the same sparsity pattern as the upper triangle
189    let (rows, cols, _) = sparsity.find();
190    let (m, n) = sparsity.shape();
191    let zeros = vec![0.0; rows.len()];
192    let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n), false)?;
193
194    // Create a mutable copy of x for perturbing
195    let mut x_perturbed = x.to_owned();
196
197    // Choose between parallel and serial execution
198    let parallel = options
199        .parallel
200        .as_ref()
201        .map(|p| p.num_workers.unwrap_or(1) > 1)
202        .unwrap_or(false);
203
204    // First set of function evaluations for the diagonal elements
205    let diag_evals: Vec<f64> = if parallel {
206        (0..n)
207            .into_par_iter()
208            .map(|i| {
209                let mut x_local = x.to_owned();
210                x_local[i] += h[i];
211                func(&x_local.view())
212            })
213            .collect()
214    } else {
215        let mut diag_vals = vec![0.0; n];
216        for i in 0..n {
217            x_perturbed[i] += h[i];
218            diag_vals[i] = func(&x_perturbed.view());
219            x_perturbed[i] = x[i];
220        }
221        diag_vals
222    };
223
224    // Set diagonal elements of the Hessian
225    for i in 0..n {
226        // Calculate second derivative
227        let d2f_dxi2 = (diag_evals[i] - 2.0 * f0 + diag_evals[i]) / (h[i] * h[i]);
228
229        // Update value if in sparsity pattern
230        update_sparse_value(&mut hess, i, i, d2f_dxi2);
231    }
232
233    // Now compute off-diagonal elements
234    if parallel {
235        // For parallel evaluation, we need to collect the derivatives first and apply them later
236        let derivatives: Vec<(usize, usize, f64)> = groups
237            .par_iter()
238            .flat_map(|group| {
239                let mut derivatives = Vec::new();
240                let mut x_local = x.to_owned();
241
242                for &j in group {
243                    // Only compute upper triangle
244                    for i in 0..j {
245                        if exists_in_sparsity(&hess, i, j) {
246                            // Apply perturbation for both indices
247                            x_local[i] += h[i];
248                            x_local[j] += h[j];
249
250                            // f(x + h_i*e_i + h_j*e_j)
251                            let f_ij = func(&x_local.view());
252
253                            // f(x + h_i*e_i)
254                            x_local[j] = x[j];
255                            let f_i = diag_evals[i];
256
257                            // f(x + h_j*e_j)
258                            x_local[i] = x[i];
259                            x_local[j] += h[j];
260                            let f_j = diag_evals[j];
261
262                            // Mixed partial derivative
263                            let d2f_dxidxj = (f_ij - f_i - f_j + f0) / (h[i] * h[j]);
264
265                            // Collect derivative
266                            derivatives.push((i, j, d2f_dxidxj));
267
268                            // Reset
269                            x_local[j] = x[j];
270                        }
271                    }
272                }
273
274                derivatives
275            })
276            .collect();
277
278        // Now apply all derivatives
279        for (i, j, d2f_dxidxj) in derivatives {
280            if hess.set(i, j, d2f_dxidxj).is_err() {
281                // If this fails, just silently continue
282            }
283        }
284    } else {
285        for group in &groups {
286            for &j in group {
287                // Only compute upper triangle
288                for i in 0..j {
289                    if exists_in_sparsity(&hess, i, j) {
290                        // Apply perturbation for both indices
291                        x_perturbed[i] += h[i];
292                        x_perturbed[j] += h[j];
293
294                        // f(x + h_i*e_i + h_j*e_j)
295                        let f_ij = func(&x_perturbed.view());
296
297                        // Mixed partial derivative
298                        let d2f_dxidxj =
299                            (f_ij - diag_evals[i] - diag_evals[j] + f0) / (h[i] * h[j]);
300
301                        // Update value
302                        update_sparse_value(&mut hess, i, j, d2f_dxidxj);
303
304                        // Reset
305                        x_perturbed[i] = x[i];
306                        x_perturbed[j] = x[j];
307                    }
308                }
309            }
310        }
311    }
312
313    Ok(hess)
314}
315
316/// Computes Hessian using 3-point finite differences (more accurate but more expensive)
317fn compute_hessian_3point<F>(
318    func: F,
319    x: &ArrayView1<f64>,
320    sparsity: &CsrArray<f64>,
321    options: &SparseFiniteDiffOptions,
322) -> Result<CsrArray<f64>, OptimizeError>
323where
324    F: Fn(&ArrayView1<f64>) -> f64 + Sync,
325{
326    let n = x.len();
327
328    // Determine column groups using a graph coloring algorithm
329    let groups = determine_column_groups(sparsity, None, None)?;
330
331    // Compute step sizes
332    let h = compute_step_sizes(x, options);
333
334    // Create result matrix with the same sparsity pattern as the upper triangle
335    let (rows, cols, _) = sparsity.find();
336    let (m, n_cols) = sparsity.shape();
337    let zeros = vec![0.0; rows.len()];
338    let mut hess =
339        CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n_cols), false)?;
340
341    // Create a mutable copy of x for perturbing
342    let mut x_perturbed = x.to_owned();
343
344    // Choose between parallel and serial execution
345    let parallel = options
346        .parallel
347        .as_ref()
348        .map(|p| p.num_workers.unwrap_or(1) > 1)
349        .unwrap_or(false);
350
351    // Compute diagonal elements using 3-point formula
352    let diag_evals: Vec<(f64, f64)> = if parallel {
353        (0..n)
354            .into_par_iter()
355            .map(|i| {
356                let mut x_local = x.to_owned();
357                x_local[i] += h[i];
358                let f_plus = func(&x_local.view());
359
360                x_local[i] = x[i] - h[i];
361                let f_minus = func(&x_local.view());
362
363                (f_plus, f_minus)
364            })
365            .collect()
366    } else {
367        let mut diag_vals = vec![(0.0, 0.0); n];
368        for i in 0..n {
369            x_perturbed[i] += h[i];
370            let f_plus = func(&x_perturbed.view());
371
372            x_perturbed[i] = x[i] - h[i];
373            let f_minus = func(&x_perturbed.view());
374
375            diag_vals[i] = (f_plus, f_minus);
376            x_perturbed[i] = x[i];
377        }
378        diag_vals
379    };
380
381    // Function value at x
382    let f0 = func(x);
383
384    // Set diagonal elements using 3-point central difference
385    for i in 0..n {
386        let (f_plus, f_minus) = diag_evals[i];
387        let d2f_dxi2 = (f_plus - 2.0 * f0 + f_minus) / (h[i] * h[i]);
388        update_sparse_value(&mut hess, i, i, d2f_dxi2);
389    }
390
391    // Compute off-diagonal elements using 3-point mixed derivatives
392    if parallel {
393        let derivatives: Vec<(usize, usize, f64)> = groups
394            .par_iter()
395            .flat_map(|group| {
396                let mut derivatives = Vec::new();
397                let mut x_local = x.to_owned();
398
399                for &j in group {
400                    // Only compute upper triangle
401                    for i in 0..j {
402                        if exists_in_sparsity(&hess, i, j) {
403                            // f(x + h_i*e_i + h_j*e_j)
404                            x_local[i] += h[i];
405                            x_local[j] += h[j];
406                            let f_pp = func(&x_local.view());
407
408                            // f(x + h_i*e_i - h_j*e_j)
409                            x_local[j] = x[j] - h[j];
410                            let f_pm = func(&x_local.view());
411
412                            // f(x - h_i*e_i + h_j*e_j)
413                            x_local[i] = x[i] - h[i];
414                            x_local[j] = x[j] + h[j];
415                            let f_mp = func(&x_local.view());
416
417                            // f(x - h_i*e_i - h_j*e_j)
418                            x_local[j] = x[j] - h[j];
419                            let f_mm = func(&x_local.view());
420
421                            // Mixed partial derivative using 3-point formula
422                            let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
423
424                            derivatives.push((i, j, d2f_dxidxj));
425
426                            // Reset
427                            x_local[i] = x[i];
428                            x_local[j] = x[j];
429                        }
430                    }
431                }
432
433                derivatives
434            })
435            .collect();
436
437        // Apply all derivatives
438        for (i, j, d2f_dxidxj) in derivatives {
439            if hess.set(i, j, d2f_dxidxj).is_err() {
440                // If this fails, just silently continue
441            }
442        }
443    } else {
444        for group in &groups {
445            for &j in group {
446                // Only compute upper triangle
447                for i in 0..j {
448                    if exists_in_sparsity(&hess, i, j) {
449                        // f(x + h_i*e_i + h_j*e_j)
450                        x_perturbed[i] += h[i];
451                        x_perturbed[j] += h[j];
452                        let f_pp = func(&x_perturbed.view());
453
454                        // f(x + h_i*e_i - h_j*e_j)
455                        x_perturbed[j] = x[j] - h[j];
456                        let f_pm = func(&x_perturbed.view());
457
458                        // f(x - h_i*e_i + h_j*e_j)
459                        x_perturbed[i] = x[i] - h[i];
460                        x_perturbed[j] = x[j] + h[j];
461                        let f_mp = func(&x_perturbed.view());
462
463                        // f(x - h_i*e_i - h_j*e_j)
464                        x_perturbed[j] = x[j] - h[j];
465                        let f_mm = func(&x_perturbed.view());
466
467                        // Mixed partial derivative using 3-point formula
468                        let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
469
470                        update_sparse_value(&mut hess, i, j, d2f_dxidxj);
471
472                        // Reset
473                        x_perturbed[i] = x[i];
474                        x_perturbed[j] = x[j];
475                    }
476                }
477            }
478        }
479    }
480
481    Ok(hess)
482}
483
484/// Computes Hessian using the complex step method (highly accurate)
485fn compute_hessian_complex_step<F>(
486    _func: F,
487    _x: &ArrayView1<f64>,
488    _sparsity: &CsrArray<f64>,
489    _options: &SparseFiniteDiffOptions,
490) -> Result<CsrArray<f64>, OptimizeError>
491where
492    F: Fn(&ArrayView1<f64>) -> f64 + Sync,
493{
494    // This is a placeholder implementation that would need to be expanded
495    // with the complex step algorithm. For now, we just return an error.
496    Err(OptimizeError::NotImplementedError(
497        "Complex step method for Hessian computation is not yet implemented".to_string(),
498    ))
499}
500
501/// Ensures a sparsity pattern is symmetric
502fn make_symmetric_sparsity(sparsity: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
503    let (m, n) = sparsity.shape();
504    if m != n {
505        return Err(OptimizeError::ValueError(
506            "Sparsity pattern must be square for Hessian computation".to_string(),
507        ));
508    }
509
510    // Convert to dense for simplicity
511    let dense = sparsity.to_array();
512    let dense_transposed = dense.t().to_owned();
513
514    // Create arrays for the triplets
515    let mut data = Vec::new();
516    let mut rows = Vec::new();
517    let mut cols = Vec::new();
518
519    // Combine the original and its transpose
520    for i in 0..n {
521        for j in 0..n {
522            if dense[[i, j]] > 0.0 || dense_transposed[[i, j]] > 0.0 {
523                rows.push(i);
524                cols.push(j);
525                data.push(1.0); // Binary sparsity pattern
526            }
527        }
528    }
529
530    // Create symmetric sparsity pattern
531    Ok(CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?)
532}
533
534/// Fills the lower triangle of a Hessian matrix based on the upper triangle
535fn fill_symmetric_hessian(upper: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
536    let (n, _) = upper.shape();
537    if n != upper.shape().1 {
538        return Err(OptimizeError::ValueError(
539            "Hessian matrix must be square".to_string(),
540        ));
541    }
542
543    // We need to create a new symmetric matrix from the upper triangular matrix
544
545    // Convert the upper triangle matrix to dense temporarily
546    let upper_dense = upper.to_array();
547
548    // Create arrays for the triplets
549    let mut data = Vec::new();
550    let mut rows = Vec::new();
551    let mut cols = Vec::new();
552
553    // Collect all non-zero entries including the symmetric counterparts
554    for i in 0..n {
555        for j in 0..n {
556            let value = upper_dense[[i, j]];
557            if value != 0.0 {
558                // Add the original element
559                rows.push(i);
560                cols.push(j);
561                data.push(value);
562
563                // If not on diagonal, add the symmetric element
564                if i != j {
565                    rows.push(j);
566                    cols.push(i);
567                    data.push(value);
568                }
569            }
570        }
571    }
572
573    // Create new symmetric matrix
574    let full = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
575
576    Ok(full)
577}