1use ndarray::{Array1, ArrayView1};
7use scirs2_core::parallel_ops::*;
8use scirs2_sparse::{csr_array::CsrArray, sparray::SparseArray};
9
10use super::coloring::determine_column_groups;
11use super::finite_diff::{compute_step_sizes, SparseFiniteDiffOptions};
12use crate::error::OptimizeError;
13
14#[allow(dead_code)]
16fn update_sparse_value(matrix: &mut CsrArray<f64>, row: usize, col: usize, value: f64) {
17 if matrix.get(row, col) != 0.0 && matrix.set(row, col, value).is_err() {
19 }
21}
22
23#[allow(dead_code)]
25fn exists_in_sparsity(matrix: &CsrArray<f64>, row: usize, col: usize) -> bool {
26 matrix.get(row, col) != 0.0
27}
28
29#[allow(dead_code)]
46pub fn sparse_hessian<F, G>(
47 func: F,
48 grad: Option<G>,
49 x: &ArrayView1<f64>,
50 f0: Option<f64>,
51 g0: Option<&Array1<f64>>,
52 sparsity_pattern: Option<&CsrArray<f64>>,
53 options: Option<SparseFiniteDiffOptions>,
54) -> Result<CsrArray<f64>, OptimizeError>
55where
56 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
57 G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
58{
59 let options = options.unwrap_or_default();
60 let n = x.len();
61
62 if let Some(gradient_fn) = grad {
65 return compute_hessian_from_gradient(gradient_fn, x, g0, sparsity_pattern, &options);
66 }
67
68 let sparsity_owned: CsrArray<f64>;
70 let sparsity = match sparsity_pattern {
71 Some(p) => {
72 if p.shape().0 != n || p.shape().1 != n {
74 return Err(OptimizeError::ValueError(format!(
75 "Sparsity _pattern shape {:?} does not match input dimension {}",
76 p.shape(),
77 n
78 )));
79 }
80 p
81 }
82 None => {
83 let mut data = Vec::with_capacity(n * n);
85 let mut rows = Vec::with_capacity(n * n);
86 let mut cols = Vec::with_capacity(n * n);
87
88 for i in 0..n {
89 for j in 0..n {
90 data.push(1.0);
91 rows.push(i);
92 cols.push(j);
93 }
94 }
95
96 sparsity_owned = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
97 &sparsity_owned
98 }
99 };
100
101 let symmetric_sparsity = make_symmetric_sparsity(sparsity)?;
105
106 let result = match options.method.as_str() {
108 "2-point" => {
109 let f0_val = f0.unwrap_or_else(|| func(x));
110 compute_hessian_2point(func, x, f0_val, &symmetric_sparsity, &options)
111 }
112 "3-point" => compute_hessian_3point(func, x, &symmetric_sparsity, &options),
113 "cs" => compute_hessian_complex_step(func, x, &symmetric_sparsity, &options),
114 _ => Err(OptimizeError::ValueError(format!(
115 "Unknown method: {}. Valid options are '2-point', '3-point', and 'cs'",
116 options.method
117 ))),
118 }?;
119
120 fill_symmetric_hessian(&result)
122}
123
124#[allow(dead_code)]
126fn compute_hessian_from_gradient<G>(
127 grad_fn: G,
128 x: &ArrayView1<f64>,
129 g0: Option<&Array1<f64>>,
130 sparsity_pattern: Option<&CsrArray<f64>>,
131 options: &SparseFiniteDiffOptions,
132) -> Result<CsrArray<f64>, OptimizeError>
133where
134 G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
135{
136 let _n = x.len();
137
138 let g0_owned: Array1<f64>;
140 let g0_ref = match g0 {
141 Some(g) => g,
142 None => {
143 g0_owned = grad_fn(x);
144 &g0_owned
145 }
146 };
147
148 let jac_options = SparseFiniteDiffOptions {
151 method: options.method.clone(),
152 rel_step: options.rel_step,
153 abs_step: options.abs_step,
154 bounds: options.bounds.clone(),
155 parallel: options.parallel.clone(),
156 seed: options.seed,
157 max_group_size: options.max_group_size,
158 };
159
160 let hessian = super::jacobian::sparse_jacobian(
162 grad_fn,
163 x,
164 Some(g0_ref),
165 sparsity_pattern,
166 Some(jac_options),
167 )?;
168
169 fill_symmetric_hessian(&hessian)
171}
172
173#[allow(dead_code)]
175fn compute_hessian_2point<F>(
176 func: F,
177 x: &ArrayView1<f64>,
178 f0: f64,
179 sparsity: &CsrArray<f64>,
180 options: &SparseFiniteDiffOptions,
181) -> Result<CsrArray<f64>, OptimizeError>
182where
183 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
184{
185 let _n = x.len();
186
187 let groups = determine_column_groups(sparsity, None, None)?;
189
190 let h = compute_step_sizes(x, options);
192
193 let (rows, cols, _) = sparsity.find();
195 let (m, n) = sparsity.shape();
196 let zeros = vec![0.0; rows.len()];
197 let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n), false)?;
198
199 let mut x_perturbed = x.to_owned();
201
202 let parallel = options
204 .parallel
205 .as_ref()
206 .map(|p| p.num_workers.unwrap_or(1) > 1)
207 .unwrap_or(false);
208
209 let diag_evals: Vec<f64> = if parallel {
211 (0..n)
212 .into_par_iter()
213 .map(|i| {
214 let mut x_local = x.to_owned();
215 x_local[i] += h[i];
216 func(&x_local.view())
217 })
218 .collect()
219 } else {
220 let mut diag_vals = vec![0.0; n];
221 for i in 0..n {
222 x_perturbed[i] += h[i];
223 diag_vals[i] = func(&x_perturbed.view());
224 x_perturbed[i] = x[i];
225 }
226 diag_vals
227 };
228
229 for i in 0..n {
231 let d2f_dxi2 = (diag_evals[i] - 2.0 * f0 + diag_evals[i]) / (h[i] * h[i]);
233
234 update_sparse_value(&mut hess, i, i, d2f_dxi2);
236 }
237
238 if parallel {
240 let derivatives: Vec<(usize, usize, f64)> = groups
242 .par_iter()
243 .flat_map(|group| {
244 let mut derivatives = Vec::new();
245 let mut x_local = x.to_owned();
246
247 for &j in group {
248 for i in 0..j {
250 if exists_in_sparsity(&hess, i, j) {
251 x_local[i] += h[i];
253 x_local[j] += h[j];
254
255 let f_ij = func(&x_local.view());
257
258 x_local[j] = x[j];
260 let f_i = diag_evals[i];
261
262 x_local[i] = x[i];
264 x_local[j] += h[j];
265 let f_j = diag_evals[j];
266
267 let d2f_dxidxj = (f_ij - f_i - f_j + f0) / (h[i] * h[j]);
269
270 derivatives.push((i, j, d2f_dxidxj));
272
273 x_local[j] = x[j];
275 }
276 }
277 }
278
279 derivatives
280 })
281 .collect();
282
283 for (i, j, d2f_dxidxj) in derivatives {
285 if hess.set(i, j, d2f_dxidxj).is_err() {
286 }
288 }
289 } else {
290 for group in &groups {
291 for &j in group {
292 for i in 0..j {
294 if exists_in_sparsity(&hess, i, j) {
295 x_perturbed[i] += h[i];
297 x_perturbed[j] += h[j];
298
299 let f_ij = func(&x_perturbed.view());
301
302 let d2f_dxidxj =
304 (f_ij - diag_evals[i] - diag_evals[j] + f0) / (h[i] * h[j]);
305
306 update_sparse_value(&mut hess, i, j, d2f_dxidxj);
308
309 x_perturbed[i] = x[i];
311 x_perturbed[j] = x[j];
312 }
313 }
314 }
315 }
316 }
317
318 Ok(hess)
319}
320
321#[allow(dead_code)]
323fn compute_hessian_3point<F>(
324 func: F,
325 x: &ArrayView1<f64>,
326 sparsity: &CsrArray<f64>,
327 options: &SparseFiniteDiffOptions,
328) -> Result<CsrArray<f64>, OptimizeError>
329where
330 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
331{
332 let n = x.len();
333
334 let groups = determine_column_groups(sparsity, None, None)?;
336
337 let h = compute_step_sizes(x, options);
339
340 let (rows, cols, _) = sparsity.find();
342 let (m, n_cols) = sparsity.shape();
343 let zeros = vec![0.0; rows.len()];
344 let mut hess =
345 CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n_cols), false)?;
346
347 let mut x_perturbed = x.to_owned();
349
350 let parallel = options
352 .parallel
353 .as_ref()
354 .map(|p| p.num_workers.unwrap_or(1) > 1)
355 .unwrap_or(false);
356
357 let diag_evals: Vec<(f64, f64)> = if parallel {
359 (0..n)
360 .into_par_iter()
361 .map(|i| {
362 let mut x_local = x.to_owned();
363 x_local[i] += h[i];
364 let f_plus = func(&x_local.view());
365
366 x_local[i] = x[i] - h[i];
367 let f_minus = func(&x_local.view());
368
369 (f_plus, f_minus)
370 })
371 .collect()
372 } else {
373 let mut diag_vals = vec![(0.0, 0.0); n];
374 for i in 0..n {
375 x_perturbed[i] += h[i];
376 let f_plus = func(&x_perturbed.view());
377
378 x_perturbed[i] = x[i] - h[i];
379 let f_minus = func(&x_perturbed.view());
380
381 diag_vals[i] = (f_plus, f_minus);
382 x_perturbed[i] = x[i];
383 }
384 diag_vals
385 };
386
387 let f0 = func(x);
389
390 for i in 0..n {
392 let (f_plus, f_minus) = diag_evals[i];
393 let d2f_dxi2 = (f_plus - 2.0 * f0 + f_minus) / (h[i] * h[i]);
394 update_sparse_value(&mut hess, i, i, d2f_dxi2);
395 }
396
397 if parallel {
399 let derivatives: Vec<(usize, usize, f64)> = groups
400 .par_iter()
401 .flat_map(|group| {
402 let mut derivatives = Vec::new();
403 let mut x_local = x.to_owned();
404
405 for &j in group {
406 for i in 0..j {
408 if exists_in_sparsity(&hess, i, j) {
409 x_local[i] += h[i];
411 x_local[j] += h[j];
412 let f_pp = func(&x_local.view());
413
414 x_local[j] = x[j] - h[j];
416 let f_pm = func(&x_local.view());
417
418 x_local[i] = x[i] - h[i];
420 x_local[j] = x[j] + h[j];
421 let f_mp = func(&x_local.view());
422
423 x_local[j] = x[j] - h[j];
425 let f_mm = func(&x_local.view());
426
427 let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
429
430 derivatives.push((i, j, d2f_dxidxj));
431
432 x_local[i] = x[i];
434 x_local[j] = x[j];
435 }
436 }
437 }
438
439 derivatives
440 })
441 .collect();
442
443 for (i, j, d2f_dxidxj) in derivatives {
445 if hess.set(i, j, d2f_dxidxj).is_err() {
446 }
448 }
449 } else {
450 for group in &groups {
451 for &j in group {
452 for i in 0..j {
454 if exists_in_sparsity(&hess, i, j) {
455 x_perturbed[i] += h[i];
457 x_perturbed[j] += h[j];
458 let f_pp = func(&x_perturbed.view());
459
460 x_perturbed[j] = x[j] - h[j];
462 let f_pm = func(&x_perturbed.view());
463
464 x_perturbed[i] = x[i] - h[i];
466 x_perturbed[j] = x[j] + h[j];
467 let f_mp = func(&x_perturbed.view());
468
469 x_perturbed[j] = x[j] - h[j];
471 let f_mm = func(&x_perturbed.view());
472
473 let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
475
476 update_sparse_value(&mut hess, i, j, d2f_dxidxj);
477
478 x_perturbed[i] = x[i];
480 x_perturbed[j] = x[j];
481 }
482 }
483 }
484 }
485 }
486
487 Ok(hess)
488}
489
490#[allow(dead_code)]
497fn compute_hessian_complex_step<F>(
498 func: F,
499 x: &ArrayView1<f64>,
500 sparsity: &CsrArray<f64>,
501 options: &SparseFiniteDiffOptions,
502) -> Result<CsrArray<f64>, OptimizeError>
503where
504 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
505{
506 let n = x.len();
507
508 let h = options.abs_step.unwrap_or(1e-20);
510
511 let groups = determine_column_groups(sparsity, None, None)?;
513
514 let (rows, cols, _) = sparsity.find();
516 let zeros = vec![0.0; rows.len()];
517 let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (n, n), false)?;
518
519 let parallel = options
521 .parallel
522 .as_ref()
523 .map(|p| p.num_workers.unwrap_or(1) > 1)
524 .unwrap_or(false);
525
526 let _f0 = func(x);
528
529 if parallel {
530 let derivatives: Vec<(usize, usize, f64)> = groups
532 .par_iter()
533 .flat_map(|group| {
534 let mut derivatives = Vec::new();
535
536 for &j in group {
537 if exists_in_sparsity(&hess, j, j) {
539 let d2f_dxj2 = compute_hessian_diagonal_complex_step(&func, x, j, h);
540 derivatives.push((j, j, d2f_dxj2));
541 }
542
543 for i in 0..j {
545 if exists_in_sparsity(&hess, i, j) {
546 let d2f_dxidxj = compute_hessian_mixed_complex_step(&func, x, i, j, h);
547 derivatives.push((i, j, d2f_dxidxj));
548 }
549 }
550 }
551
552 derivatives
553 })
554 .collect();
555
556 for (i, j, derivative) in derivatives {
558 if hess.set(i, j, derivative).is_err() {
559 }
561 }
562 } else {
563 for group in &groups {
565 for &j in group {
566 if exists_in_sparsity(&hess, j, j) {
568 let d2f_dxj2 = compute_hessian_diagonal_complex_step(&func, x, j, h);
569 update_sparse_value(&mut hess, j, j, d2f_dxj2);
570 }
571
572 for i in 0..j {
574 if exists_in_sparsity(&hess, i, j) {
575 let d2f_dxidxj = compute_hessian_mixed_complex_step(&func, x, i, j, h);
576 update_sparse_value(&mut hess, i, j, d2f_dxidxj);
577 }
578 }
579 }
580 }
581 }
582
583 Ok(hess)
584}
585
586#[allow(dead_code)]
588fn compute_hessian_diagonal_complex_step<F>(func: &F, x: &ArrayView1<f64>, i: usize, h: f64) -> f64
589where
590 F: Fn(&ArrayView1<f64>) -> f64,
591{
592 let mut x_plus = x.to_owned();
596 let mut x_minus = x.to_owned();
597 let mut x_plus2 = x.to_owned();
598 let mut x_minus2 = x.to_owned();
599
600 x_plus[i] += h;
601 x_minus[i] -= h;
602 x_plus2[i] += 2.0 * h;
603 x_minus2[i] -= 2.0 * h;
604
605 let f_plus = func(&x_plus.view());
606 let f_minus = func(&x_minus.view());
607 let f_plus2 = func(&x_plus2.view());
608 let f_minus2 = func(&x_minus2.view());
609 let f0 = func(x);
610
611 (-f_plus2 + 16.0 * f_plus - 30.0 * f0 + 16.0 * f_minus - f_minus2) / (12.0 * h * h)
614}
615
616#[allow(dead_code)]
618fn compute_hessian_mixed_complex_step<F>(
619 func: &F,
620 x: &ArrayView1<f64>,
621 i: usize,
622 j: usize,
623 h: f64,
624) -> f64
625where
626 F: Fn(&ArrayView1<f64>) -> f64,
627{
628 let mut x_pp = x.to_owned();
633 x_pp[i] += h;
634 x_pp[j] += h;
635 let f_pp = func(&x_pp.view());
636
637 let mut x_pm = x.to_owned();
639 x_pm[i] += h;
640 x_pm[j] -= h;
641 let f_pm = func(&x_pm.view());
642
643 let mut x_mp = x.to_owned();
645 x_mp[i] -= h;
646 x_mp[j] += h;
647 let f_mp = func(&x_mp.view());
648
649 let mut x_mm = x.to_owned();
651 x_mm[i] -= h;
652 x_mm[j] -= h;
653 let f_mm = func(&x_mm.view());
654
655 (f_pp - f_pm - f_mp + f_mm) / (4.0 * h * h)
658}
659
660#[allow(dead_code)]
662fn make_symmetric_sparsity(sparsity: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
663 let (m, n) = sparsity.shape();
664 if m != n {
665 return Err(OptimizeError::ValueError(
666 "Sparsity pattern must be square for Hessian computation".to_string(),
667 ));
668 }
669
670 let dense = sparsity.to_array();
672 let dense_transposed = dense.t().to_owned();
673
674 let mut data = Vec::new();
676 let mut rows = Vec::new();
677 let mut cols = Vec::new();
678
679 for i in 0..n {
681 for j in 0..n {
682 if dense[[i, j]] > 0.0 || dense_transposed[[i, j]] > 0.0 {
683 rows.push(i);
684 cols.push(j);
685 data.push(1.0); }
687 }
688 }
689
690 Ok(CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?)
692}
693
694#[allow(dead_code)]
696fn fill_symmetric_hessian(upper: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
697 let (n, _) = upper.shape();
698 if n != upper.shape().1 {
699 return Err(OptimizeError::ValueError(
700 "Hessian matrix must be square".to_string(),
701 ));
702 }
703
704 let upper_dense = upper.to_array();
708
709 let mut data = Vec::new();
711 let mut rows = Vec::new();
712 let mut cols = Vec::new();
713
714 for i in 0..n {
716 for j in 0..n {
717 let value = upper_dense[[i, j]];
718 if value != 0.0 {
719 rows.push(i);
721 cols.push(j);
722 data.push(value);
723
724 if i != j {
726 rows.push(j);
727 cols.push(i);
728 data.push(value);
729 }
730 }
731 }
732 }
733
734 let full = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
736
737 Ok(full)
738}