1use ndarray::{Array1, ArrayView1};
7use scirs2_core::parallel_ops::*;
8use scirs2_sparse::{csr_array::CsrArray, sparray::SparseArray};
9
10use super::coloring::determine_column_groups;
11use super::finite_diff::{compute_step_sizes, SparseFiniteDiffOptions};
12use crate::error::OptimizeError;
13
14fn update_sparse_value(matrix: &mut CsrArray<f64>, row: usize, col: usize, value: f64) {
16 if matrix.get(row, col) != 0.0 && matrix.set(row, col, value).is_err() {
18 }
20}
21
22fn exists_in_sparsity(matrix: &CsrArray<f64>, row: usize, col: usize) -> bool {
24 matrix.get(row, col) != 0.0
25}
26
27pub fn sparse_hessian<F, G>(
44 func: F,
45 grad: Option<G>,
46 x: &ArrayView1<f64>,
47 f0: Option<f64>,
48 g0: Option<&Array1<f64>>,
49 sparsity_pattern: Option<&CsrArray<f64>>,
50 options: Option<SparseFiniteDiffOptions>,
51) -> Result<CsrArray<f64>, OptimizeError>
52where
53 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
54 G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
55{
56 let options = options.unwrap_or_default();
57 let n = x.len();
58
59 if let Some(gradient_fn) = grad {
62 return compute_hessian_from_gradient(gradient_fn, x, g0, sparsity_pattern, &options);
63 }
64
65 let sparsity_owned: CsrArray<f64>;
67 let sparsity = match sparsity_pattern {
68 Some(p) => {
69 if p.shape().0 != n || p.shape().1 != n {
71 return Err(OptimizeError::ValueError(format!(
72 "Sparsity pattern shape {:?} does not match input dimension {}",
73 p.shape(),
74 n
75 )));
76 }
77 p
78 }
79 None => {
80 let mut data = Vec::with_capacity(n * n);
82 let mut rows = Vec::with_capacity(n * n);
83 let mut cols = Vec::with_capacity(n * n);
84
85 for i in 0..n {
86 for j in 0..n {
87 data.push(1.0);
88 rows.push(i);
89 cols.push(j);
90 }
91 }
92
93 sparsity_owned = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
94 &sparsity_owned
95 }
96 };
97
98 let symmetric_sparsity = make_symmetric_sparsity(sparsity)?;
102
103 let result = match options.method.as_str() {
105 "2-point" => {
106 let f0_val = f0.unwrap_or_else(|| func(x));
107 compute_hessian_2point(func, x, f0_val, &symmetric_sparsity, &options)
108 }
109 "3-point" => compute_hessian_3point(func, x, &symmetric_sparsity, &options),
110 "cs" => compute_hessian_complex_step(func, x, &symmetric_sparsity, &options),
111 _ => Err(OptimizeError::ValueError(format!(
112 "Unknown method: {}. Valid options are '2-point', '3-point', and 'cs'",
113 options.method
114 ))),
115 }?;
116
117 fill_symmetric_hessian(&result)
119}
120
121fn compute_hessian_from_gradient<G>(
123 grad_fn: G,
124 x: &ArrayView1<f64>,
125 g0: Option<&Array1<f64>>,
126 sparsity_pattern: Option<&CsrArray<f64>>,
127 options: &SparseFiniteDiffOptions,
128) -> Result<CsrArray<f64>, OptimizeError>
129where
130 G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
131{
132 let _n = x.len();
133
134 let g0_owned: Array1<f64>;
136 let g0_ref = match g0 {
137 Some(g) => g,
138 None => {
139 g0_owned = grad_fn(x);
140 &g0_owned
141 }
142 };
143
144 let jac_options = SparseFiniteDiffOptions {
147 method: options.method.clone(),
148 rel_step: options.rel_step,
149 abs_step: options.abs_step,
150 bounds: options.bounds.clone(),
151 parallel: options.parallel.clone(),
152 seed: options.seed,
153 max_group_size: options.max_group_size,
154 };
155
156 let hessian = super::jacobian::sparse_jacobian(
158 grad_fn,
159 x,
160 Some(g0_ref),
161 sparsity_pattern,
162 Some(jac_options),
163 )?;
164
165 fill_symmetric_hessian(&hessian)
167}
168
169fn compute_hessian_2point<F>(
171 func: F,
172 x: &ArrayView1<f64>,
173 f0: f64,
174 sparsity: &CsrArray<f64>,
175 options: &SparseFiniteDiffOptions,
176) -> Result<CsrArray<f64>, OptimizeError>
177where
178 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
179{
180 let _n = x.len();
181
182 let groups = determine_column_groups(sparsity, None, None)?;
184
185 let h = compute_step_sizes(x, options);
187
188 let (rows, cols, _) = sparsity.find();
190 let (m, n) = sparsity.shape();
191 let zeros = vec![0.0; rows.len()];
192 let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n), false)?;
193
194 let mut x_perturbed = x.to_owned();
196
197 let parallel = options
199 .parallel
200 .as_ref()
201 .map(|p| p.num_workers.unwrap_or(1) > 1)
202 .unwrap_or(false);
203
204 let diag_evals: Vec<f64> = if parallel {
206 (0..n)
207 .into_par_iter()
208 .map(|i| {
209 let mut x_local = x.to_owned();
210 x_local[i] += h[i];
211 func(&x_local.view())
212 })
213 .collect()
214 } else {
215 let mut diag_vals = vec![0.0; n];
216 for i in 0..n {
217 x_perturbed[i] += h[i];
218 diag_vals[i] = func(&x_perturbed.view());
219 x_perturbed[i] = x[i];
220 }
221 diag_vals
222 };
223
224 for i in 0..n {
226 let d2f_dxi2 = (diag_evals[i] - 2.0 * f0 + diag_evals[i]) / (h[i] * h[i]);
228
229 update_sparse_value(&mut hess, i, i, d2f_dxi2);
231 }
232
233 if parallel {
235 let derivatives: Vec<(usize, usize, f64)> = groups
237 .par_iter()
238 .flat_map(|group| {
239 let mut derivatives = Vec::new();
240 let mut x_local = x.to_owned();
241
242 for &j in group {
243 for i in 0..j {
245 if exists_in_sparsity(&hess, i, j) {
246 x_local[i] += h[i];
248 x_local[j] += h[j];
249
250 let f_ij = func(&x_local.view());
252
253 x_local[j] = x[j];
255 let f_i = diag_evals[i];
256
257 x_local[i] = x[i];
259 x_local[j] += h[j];
260 let f_j = diag_evals[j];
261
262 let d2f_dxidxj = (f_ij - f_i - f_j + f0) / (h[i] * h[j]);
264
265 derivatives.push((i, j, d2f_dxidxj));
267
268 x_local[j] = x[j];
270 }
271 }
272 }
273
274 derivatives
275 })
276 .collect();
277
278 for (i, j, d2f_dxidxj) in derivatives {
280 if hess.set(i, j, d2f_dxidxj).is_err() {
281 }
283 }
284 } else {
285 for group in &groups {
286 for &j in group {
287 for i in 0..j {
289 if exists_in_sparsity(&hess, i, j) {
290 x_perturbed[i] += h[i];
292 x_perturbed[j] += h[j];
293
294 let f_ij = func(&x_perturbed.view());
296
297 let d2f_dxidxj =
299 (f_ij - diag_evals[i] - diag_evals[j] + f0) / (h[i] * h[j]);
300
301 update_sparse_value(&mut hess, i, j, d2f_dxidxj);
303
304 x_perturbed[i] = x[i];
306 x_perturbed[j] = x[j];
307 }
308 }
309 }
310 }
311 }
312
313 Ok(hess)
314}
315
/// Central-difference ("3-point") sparse Hessian from function values only.
///
/// Diagonal entries use the central second difference
///   d²f/dx_i² ≈ (f(x + h_i e_i) - 2 f(x) + f(x - h_i e_i)) / h_i²,
/// and cross terms the four-point central formula
///   d²f/dx_i dx_j ≈ (f_pp - f_pm - f_mp + f_mm) / (4 h_i h_j),
/// where the subscripts denote the sign of the step taken in x_i and x_j.
///
/// Only entries of `sparsity` in the diagonal and upper triangle are filled;
/// the caller symmetrizes the result via `fill_symmetric_hessian`.
fn compute_hessian_3point<F>(
    func: F,
    x: &ArrayView1<f64>,
    sparsity: &CsrArray<f64>,
    options: &SparseFiniteDiffOptions,
) -> Result<CsrArray<f64>, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Sync,
{
    let n = x.len();

    // Group structurally independent columns (graph coloring).
    let groups = determine_column_groups(sparsity, None, None)?;

    // Per-coordinate finite-difference step sizes.
    let h = compute_step_sizes(x, options);

    // Zero-valued matrix with the same nonzero structure as the pattern.
    let (rows, cols, _) = sparsity.find();
    let (m, n_cols) = sparsity.shape();
    let zeros = vec![0.0; rows.len()];
    let mut hess =
        CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n_cols), false)?;

    // Scratch vector, perturbed in place and restored after every use.
    let mut x_perturbed = x.to_owned();

    // Parallel path only when more than one worker is configured.
    let parallel = options
        .parallel
        .as_ref()
        .map(|p| p.num_workers.unwrap_or(1) > 1)
        .unwrap_or(false);

    // (f(x + h_i e_i), f(x - h_i e_i)) for every coordinate i.
    let diag_evals: Vec<(f64, f64)> = if parallel {
        (0..n)
            .into_par_iter()
            .map(|i| {
                // Each task owns its scratch copy of x.
                let mut x_local = x.to_owned();
                x_local[i] += h[i];
                let f_plus = func(&x_local.view());

                x_local[i] = x[i] - h[i];
                let f_minus = func(&x_local.view());

                (f_plus, f_minus)
            })
            .collect()
    } else {
        let mut diag_vals = vec![(0.0, 0.0); n];
        for i in 0..n {
            x_perturbed[i] += h[i];
            let f_plus = func(&x_perturbed.view());

            x_perturbed[i] = x[i] - h[i];
            let f_minus = func(&x_perturbed.view());

            diag_vals[i] = (f_plus, f_minus);
            // Restore coordinate i before moving to the next one.
            x_perturbed[i] = x[i];
        }
        diag_vals
    };

    // Baseline value; no caller-supplied f0 in this scheme.
    let f0 = func(x);

    // Diagonal: central second difference.
    // NOTE(review): update_sparse_value only writes where hess currently
    // stores a nonzero, yet hess was built from all-zero triplets — whether
    // any entry is ever written depends on CsrArray::get's treatment of
    // explicitly stored zeros. Confirm against scirs2_sparse semantics.
    for i in 0..n {
        let (f_plus, f_minus) = diag_evals[i];
        let d2f_dxi2 = (f_plus - 2.0 * f0 + f_minus) / (h[i] * h[i]);
        update_sparse_value(&mut hess, i, i, d2f_dxi2);
    }

    if parallel {
        // Cross terms per group; results are collected and written back
        // sequentially because hess cannot be shared mutably across tasks.
        let derivatives: Vec<(usize, usize, f64)> = groups
            .par_iter()
            .flat_map(|group| {
                let mut derivatives = Vec::new();
                let mut x_local = x.to_owned();

                for &j in group {
                    for i in 0..j {
                        if exists_in_sparsity(&hess, i, j) {
                            // (+h_i, +h_j)
                            x_local[i] += h[i];
                            x_local[j] += h[j];
                            let f_pp = func(&x_local.view());

                            // (+h_i, -h_j)
                            x_local[j] = x[j] - h[j];
                            let f_pm = func(&x_local.view());

                            // (-h_i, +h_j)
                            x_local[i] = x[i] - h[i];
                            x_local[j] = x[j] + h[j];
                            let f_mp = func(&x_local.view());

                            // (-h_i, -h_j)
                            x_local[j] = x[j] - h[j];
                            let f_mm = func(&x_local.view());

                            let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);

                            derivatives.push((i, j, d2f_dxidxj));

                            // Restore both coordinates for the next pair.
                            x_local[i] = x[i];
                            x_local[j] = x[j];
                        }
                    }
                }

                derivatives
            })
            .collect();

        for (i, j, d2f_dxidxj) in derivatives {
            // Best-effort write; set errors are deliberately ignored.
            if hess.set(i, j, d2f_dxidxj).is_err() {
            }
        }
    } else {
        for group in &groups {
            for &j in group {
                for i in 0..j {
                    if exists_in_sparsity(&hess, i, j) {
                        // (+h_i, +h_j)
                        x_perturbed[i] += h[i];
                        x_perturbed[j] += h[j];
                        let f_pp = func(&x_perturbed.view());

                        // (+h_i, -h_j)
                        x_perturbed[j] = x[j] - h[j];
                        let f_pm = func(&x_perturbed.view());

                        // (-h_i, +h_j)
                        x_perturbed[i] = x[i] - h[i];
                        x_perturbed[j] = x[j] + h[j];
                        let f_mp = func(&x_perturbed.view());

                        // (-h_i, -h_j)
                        x_perturbed[j] = x[j] - h[j];
                        let f_mm = func(&x_perturbed.view());

                        let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);

                        update_sparse_value(&mut hess, i, j, d2f_dxidxj);

                        // Restore both coordinates for the next pair.
                        x_perturbed[i] = x[i];
                        x_perturbed[j] = x[j];
                    }
                }
            }
        }
    }

    Ok(hess)
}
483
484fn compute_hessian_complex_step<F>(
486 _func: F,
487 _x: &ArrayView1<f64>,
488 _sparsity: &CsrArray<f64>,
489 _options: &SparseFiniteDiffOptions,
490) -> Result<CsrArray<f64>, OptimizeError>
491where
492 F: Fn(&ArrayView1<f64>) -> f64 + Sync,
493{
494 Err(OptimizeError::NotImplementedError(
497 "Complex step method for Hessian computation is not yet implemented".to_string(),
498 ))
499}
500
501fn make_symmetric_sparsity(sparsity: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
503 let (m, n) = sparsity.shape();
504 if m != n {
505 return Err(OptimizeError::ValueError(
506 "Sparsity pattern must be square for Hessian computation".to_string(),
507 ));
508 }
509
510 let dense = sparsity.to_array();
512 let dense_transposed = dense.t().to_owned();
513
514 let mut data = Vec::new();
516 let mut rows = Vec::new();
517 let mut cols = Vec::new();
518
519 for i in 0..n {
521 for j in 0..n {
522 if dense[[i, j]] > 0.0 || dense_transposed[[i, j]] > 0.0 {
523 rows.push(i);
524 cols.push(j);
525 data.push(1.0); }
527 }
528 }
529
530 Ok(CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?)
532}
533
534fn fill_symmetric_hessian(upper: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
536 let (n, _) = upper.shape();
537 if n != upper.shape().1 {
538 return Err(OptimizeError::ValueError(
539 "Hessian matrix must be square".to_string(),
540 ));
541 }
542
543 let upper_dense = upper.to_array();
547
548 let mut data = Vec::new();
550 let mut rows = Vec::new();
551 let mut cols = Vec::new();
552
553 for i in 0..n {
555 for j in 0..n {
556 let value = upper_dense[[i, j]];
557 if value != 0.0 {
558 rows.push(i);
560 cols.push(j);
561 data.push(value);
562
563 if i != j {
565 rows.push(j);
566 cols.push(i);
567 data.push(value);
568 }
569 }
570 }
571 }
572
573 let full = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
575
576 Ok(full)
577}