scirs2_optimize/unconstrained/
memory_efficient.rs1use crate::error::OptimizeError;
8use crate::unconstrained::line_search::backtracking_line_search;
9use crate::unconstrained::result::OptimizeResult;
10use crate::unconstrained::utils::check_convergence;
11use crate::unconstrained::Options;
12use scirs2_core::ndarray::{Array1, ArrayView1};
13use std::collections::VecDeque;
14
15#[derive(Debug, Clone)]
17pub struct MemoryOptions {
18 pub base_options: Options,
20 pub max_memory_bytes: usize,
22 pub chunk_size: usize,
24 pub max_history: usize,
26 pub use_memory_pool: bool,
28 pub use_out_of_core: bool,
30 pub temp_dir: Option<std::path::PathBuf>,
32}
33
34impl Default for MemoryOptions {
35 fn default() -> Self {
36 Self {
37 base_options: Options::default(),
38 max_memory_bytes: 0, chunk_size: 1024, max_history: 10, use_memory_pool: true,
42 use_out_of_core: false,
43 temp_dir: None,
44 }
45 }
46}
47
48struct MemoryPool {
50 array_pool: VecDeque<Array1<f64>>,
51 max_pool_size: usize,
52}
53
54impl MemoryPool {
55 fn new(max_size: usize) -> Self {
56 Self {
57 array_pool: VecDeque::new(),
58 max_pool_size: max_size,
59 }
60 }
61
62 fn get_array(&mut self, size: usize) -> Array1<f64> {
63 for i in 0..self.array_pool.len() {
65 if self.array_pool[i].len() == size {
66 return self.array_pool.remove(i).unwrap();
67 }
68 }
69 Array1::zeros(size)
71 }
72
73 fn return_array(&mut self, mut array: Array1<f64>) {
74 if self.array_pool.len() < self.max_pool_size {
75 array.fill(0.0);
77 self.array_pool.push_back(array);
78 }
79 }
81}
82
83struct StreamingGradient {
85 chunk_size: usize,
86 eps: f64,
87}
88
89impl StreamingGradient {
90 fn new(chunk_size: usize, eps: f64) -> Self {
91 Self { chunk_size, eps }
92 }
93
94 fn compute<F, S>(&self, fun: &mut F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
96 where
97 F: FnMut(&ArrayView1<f64>) -> S,
98 S: Into<f64>,
99 {
100 let n = x.len();
101 let mut grad = Array1::zeros(n);
102 let f0 = fun(x).into();
103
104 let mut x_pert = x.to_owned();
105
106 for chunk_start in (0..n).step_by(self.chunk_size) {
108 let chunk_end = std::cmp::min(chunk_start + self.chunk_size, n);
109
110 for i in chunk_start..chunk_end {
111 let h = self.eps * (1.0 + x[i].abs());
112 x_pert[i] = x[i] + h;
113
114 let f_plus = fun(&x_pert.view()).into();
115
116 if !f_plus.is_finite() {
117 return Err(OptimizeError::ComputationError(
118 "Function returned non-finite value during gradient computation"
119 .to_string(),
120 ));
121 }
122
123 grad[i] = (f_plus - f0) / h;
124 x_pert[i] = x[i]; }
126
127 }
130
131 Ok(grad)
132 }
133}
134
135#[allow(dead_code)]
137pub fn minimize_memory_efficient_lbfgs<F, S>(
138 mut fun: F,
139 x0: Array1<f64>,
140 options: &MemoryOptions,
141) -> Result<OptimizeResult<S>, OptimizeError>
142where
143 F: FnMut(&ArrayView1<f64>) -> S + Clone,
144 S: Into<f64> + Clone,
145{
146 let n = x0.len();
147 let base_opts = &options.base_options;
148
149 let mut memory_pool = if options.use_memory_pool {
151 Some(MemoryPool::new(options.max_history * 2))
152 } else {
153 None
154 };
155
156 let estimated_memory = estimate_memory_usage(n, options.max_history);
158 if options.max_memory_bytes > 0 && estimated_memory > options.max_memory_bytes {
159 return Err(OptimizeError::ValueError(format!(
160 "Estimated memory usage ({} bytes) exceeds limit ({} bytes). Consider reducing max_history or chunk_size.",
161 estimated_memory, options.max_memory_bytes
162 )));
163 }
164
165 let mut x = x0.to_owned();
167 let bounds = base_opts.bounds.as_ref();
168
169 if let Some(bounds) = bounds {
171 bounds.project(x.as_slice_mut().unwrap());
172 }
173
174 let mut f = fun(&x.view()).into();
175
176 let streaming_grad = StreamingGradient::new(options.chunk_size, base_opts.eps);
178
179 let mut g = streaming_grad.compute(&mut fun, &x.view())?;
181
182 let mut s_history: VecDeque<Array1<f64>> = VecDeque::new();
184 let mut y_history: VecDeque<Array1<f64>> = VecDeque::new();
185
186 let mut iter = 0;
188 let mut nfev = 1 + n.div_ceil(options.chunk_size); while iter < base_opts.max_iter {
192 if g.mapv(|gi| gi.abs()).sum() < base_opts.gtol {
194 break;
195 }
196
197 let mut p = if s_history.is_empty() {
199 get_array_from_pool(&mut memory_pool, n, |_| -&g)
201 } else {
202 compute_lbfgs_direction_memory_efficient(&g, &s_history, &y_history, &mut memory_pool)?
203 };
204
205 if let Some(bounds) = bounds {
207 for i in 0..n {
208 let mut can_decrease = true;
209 let mut can_increase = true;
210
211 if let Some(lb) = bounds.lower[i] {
213 if x[i] <= lb + base_opts.eps {
214 can_decrease = false;
215 }
216 }
217 if let Some(ub) = bounds.upper[i] {
218 if x[i] >= ub - base_opts.eps {
219 can_increase = false;
220 }
221 }
222
223 if (g[i] > 0.0 && !can_decrease) || (g[i] < 0.0 && !can_increase) {
225 p[i] = 0.0;
226 }
227 }
228
229 if p.mapv(|pi| pi.abs()).sum() < 1e-10 {
231 return_array_to_pool(&mut memory_pool, p);
232 break;
233 }
234 }
235
236 let alpha_init = 1.0;
238 let (alpha, f_new) = backtracking_line_search(
239 &mut fun,
240 &x.view(),
241 f,
242 &p.view(),
243 &g.view(),
244 alpha_init,
245 0.0001,
246 0.5,
247 bounds,
248 );
249
250 nfev += 1;
251
252 let s = get_array_from_pool(&mut memory_pool, n, |_| alpha * &p);
254 let x_new = &x + &s;
255
256 if array_norm_chunked(&s, options.chunk_size) < base_opts.xtol {
258 return_array_to_pool(&mut memory_pool, s);
259 return_array_to_pool(&mut memory_pool, p);
260 x = x_new;
261 break;
262 }
263
264 let g_new = streaming_grad.compute(&mut fun, &x_new.view())?;
266 nfev += n.div_ceil(options.chunk_size);
267
268 let y = get_array_from_pool(&mut memory_pool, n, |_| &g_new - &g);
270
271 if check_convergence(
273 f - f_new,
274 0.0,
275 g_new.mapv(|gi| gi.abs()).sum(),
276 base_opts.ftol,
277 0.0,
278 base_opts.gtol,
279 ) {
280 return_array_to_pool(&mut memory_pool, s);
281 return_array_to_pool(&mut memory_pool, y);
282 return_array_to_pool(&mut memory_pool, p);
283 x = x_new;
284 g = g_new;
285 break;
286 }
287
288 let s_dot_y = chunked_dot_product(&s, &y, options.chunk_size);
290 if s_dot_y > 1e-10 {
291 s_history.push_back(s);
293 y_history.push_back(y);
294
295 while s_history.len() > options.max_history {
297 if let Some(old_s) = s_history.pop_front() {
298 return_array_to_pool(&mut memory_pool, old_s);
299 }
300 if let Some(old_y) = y_history.pop_front() {
301 return_array_to_pool(&mut memory_pool, old_y);
302 }
303 }
304 } else {
305 return_array_to_pool(&mut memory_pool, s);
307 return_array_to_pool(&mut memory_pool, y);
308 }
309
310 return_array_to_pool(&mut memory_pool, p);
311
312 x = x_new;
314 f = f_new;
315 g = g_new;
316 iter += 1;
317 }
318
319 while let Some(s) = s_history.pop_front() {
321 return_array_to_pool(&mut memory_pool, s);
322 }
323 while let Some(y) = y_history.pop_front() {
324 return_array_to_pool(&mut memory_pool, y);
325 }
326
327 if let Some(bounds) = bounds {
329 bounds.project(x.as_slice_mut().unwrap());
330 }
331
332 let final_fun = fun(&x.view());
334
335 Ok(OptimizeResult {
336 x,
337 fun: final_fun,
338 nit: iter,
339 func_evals: nfev,
340 nfev,
341 success: iter < base_opts.max_iter,
342 message: if iter < base_opts.max_iter {
343 "Memory-efficient optimization terminated successfully.".to_string()
344 } else {
345 "Maximum iterations reached.".to_string()
346 },
347 jacobian: Some(g),
348 hessian: None,
349 })
350}
351
352#[allow(dead_code)]
354fn compute_lbfgs_direction_memory_efficient(
355 g: &Array1<f64>,
356 s_history: &VecDeque<Array1<f64>>,
357 y_history: &VecDeque<Array1<f64>>,
358 memory_pool: &mut Option<MemoryPool>,
359) -> Result<Array1<f64>, OptimizeError> {
360 let m = s_history.len();
361 if m == 0 {
362 return Ok(-g);
363 }
364
365 let n = g.len();
366 let mut q = get_array_from_pool(memory_pool, n, |_| g.clone());
367 let mut alpha = vec![0.0; m];
368
369 for i in (0..m).rev() {
371 let rho_i = 1.0 / y_history[i].dot(&s_history[i]);
372 alpha[i] = rho_i * s_history[i].dot(&q);
373 let temp = get_array_from_pool(memory_pool, n, |_| &q - alpha[i] * &y_history[i]);
374 return_array_to_pool(memory_pool, q);
375 q = temp;
376 }
377
378 let gamma = if m > 0 {
380 s_history[m - 1].dot(&y_history[m - 1]) / y_history[m - 1].dot(&y_history[m - 1])
381 } else {
382 1.0
383 };
384
385 let mut r = get_array_from_pool(memory_pool, n, |_| gamma * &q);
386 return_array_to_pool(memory_pool, q);
387
388 for i in 0..m {
390 let rho_i = 1.0 / y_history[i].dot(&s_history[i]);
391 let beta = rho_i * y_history[i].dot(&r);
392 let temp = get_array_from_pool(memory_pool, n, |_| &r + (alpha[i] - beta) * &s_history[i]);
393 return_array_to_pool(memory_pool, r);
394 r = temp;
395 }
396
397 let result = get_array_from_pool(memory_pool, n, |_| -&r);
399 return_array_to_pool(memory_pool, r);
400 Ok(result)
401}
402
403#[allow(dead_code)]
405fn get_array_from_pool<F>(
406 memory_pool: &mut Option<MemoryPool>,
407 size: usize,
408 init_fn: F,
409) -> Array1<f64>
410where
411 F: FnOnce(usize) -> Array1<f64>,
412{
413 match memory_pool {
414 Some(pool) => {
415 let mut array = pool.get_array(size);
416 if array.len() != size {
417 array = Array1::zeros(size);
418 }
419 let result = init_fn(size);
420 pool.return_array(array);
421 result
422 }
423 None => init_fn(size),
424 }
425}
426
427#[allow(dead_code)]
429fn return_array_to_pool(_memory_pool: &mut Option<MemoryPool>, array: Array1<f64>) {
430 if let Some(pool) = _memory_pool {
431 pool.return_array(array);
432 }
433 }
435
436#[allow(dead_code)]
438fn chunked_dot_product(a: &Array1<f64>, b: &Array1<f64>, chunk_size: usize) -> f64 {
439 let n = a.len();
440 let mut result = 0.0;
441
442 for chunk_start in (0..n).step_by(chunk_size) {
443 let chunk_end = std::cmp::min(chunk_start + chunk_size, n);
444 let a_chunk = a.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
445 let b_chunk = b.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
446 result += a_chunk.dot(&b_chunk);
447 }
448
449 result
450}
451
452#[allow(dead_code)]
454fn array_norm_chunked(array: &Array1<f64>, chunk_size: usize) -> f64 {
455 let n = array.len();
456 let mut sum_sq: f64 = 0.0;
457
458 for chunk_start in (0..n).step_by(chunk_size) {
459 let chunk_end = std::cmp::min(chunk_start + chunk_size, n);
460 let chunk = array.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
461 sum_sq += chunk.mapv(|x| x.powi(2)).sum();
462 }
463
464 sum_sq.sqrt()
465}
466
/// Rough estimate, in bytes, of the L-BFGS working set for `n` variables and
/// `maxhistory` stored correction pairs.
///
/// The estimate counts the following `f64` vectors of length `n`:
/// * 2 for the current state;
/// * `2 * maxhistory` for the `s`/`y` history;
/// * 4 temporaries.
///
/// Arithmetic saturates at `usize::MAX`. Without saturation, huge inputs would
/// wrap to a small number in release builds and slip past a memory limit (or
/// panic in debug builds).
#[allow(dead_code)]
fn estimate_memory_usage(n: usize, maxhistory: usize) -> usize {
    const F64_SIZE: usize = std::mem::size_of::<f64>();

    let vector_bytes = n.saturating_mul(F64_SIZE);

    // x and g.
    let current_vars = vector_bytes.saturating_mul(2);
    // s and y history.
    let history_size = vector_bytes.saturating_mul(maxhistory).saturating_mul(2);
    // Scratch vectors used during an iteration.
    let temp_arrays = vector_bytes.saturating_mul(4);

    current_vars
        .saturating_add(history_size)
        .saturating_add(temp_arrays)
}
484
485#[allow(dead_code)]
487pub fn create_memory_efficient_optimizer(
488 problem_size: usize,
489 available_memory_mb: usize,
490) -> MemoryOptions {
491 let available_bytes = available_memory_mb * 1024 * 1024;
492
493 let max_history = std::cmp::min(
495 20,
496 available_bytes / (2 * problem_size * std::mem::size_of::<f64>() * 4),
497 )
498 .max(1);
499
500 let chunk_size = std::cmp::min(
501 problem_size,
502 std::cmp::max(64, available_bytes / (8 * std::mem::size_of::<f64>())),
503 );
504
505 MemoryOptions {
506 base_options: Options::default(),
507 max_memory_bytes: available_bytes,
508 chunk_size,
509 max_history,
510 use_memory_pool: true,
511 use_out_of_core: available_memory_mb < 100, temp_dir: None,
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The separable quadratic `sum(x_i^2)` should be driven to the origin.
    #[test]
    fn test_memory_efficient_lbfgs_quadratic() {
        let sum_of_squares = |v: &ArrayView1<f64>| -> f64 { v.mapv(|vi| vi.powi(2)).sum() };

        let dim = 100;
        let start = Array1::ones(dim);
        let options = MemoryOptions {
            chunk_size: 32,
            max_history: 5,
            ..MemoryOptions::default()
        };

        let result = minimize_memory_efficient_lbfgs(sum_of_squares, start, &options).unwrap();

        assert!(result.success);
        for &xi in result.x.iter().take(dim.min(10)) {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    /// Chunked reductions must agree with their unchunked counterparts.
    #[test]
    fn test_chunked_operations() {
        let a: Array1<f64> = (0..100_i32).map(f64::from).collect();
        let b: Array1<f64> = (0..100_i32).map(|i| f64::from(i * 2)).collect();

        let expected_dot = a.dot(&b);
        assert_abs_diff_eq!(chunked_dot_product(&a, &b, 10), expected_dot, epsilon = 1e-10);

        let expected_norm = a.mapv(|v| v.powi(2)).sum().sqrt();
        assert_abs_diff_eq!(array_norm_chunked(&a, 10), expected_norm, epsilon = 1e-10);
    }

    /// Borrowing and returning buffers twice should leave exactly two pooled.
    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new(3);

        for _ in 0..2 {
            let first = pool.get_array(10);
            let second = pool.get_array(10);
            pool.return_array(first);
            pool.return_array(second);
        }

        assert_eq!(pool.array_pool.len(), 2);
    }

    /// The estimate for a moderate problem is positive and below 1 MB.
    #[test]
    fn test_memory_estimation() {
        let estimated = estimate_memory_usage(1000, 10);

        assert!(estimated > 0);
        assert!(estimated < 1_000_000);
    }

    /// Auto-sized options must have usable, non-zero parameters.
    #[test]
    fn test_auto_parameter_selection() {
        let opts = create_memory_efficient_optimizer(10000, 64);

        assert!(opts.chunk_size > 0);
        assert!(opts.max_history > 0);
        assert!(opts.max_memory_bytes > 0);
    }
}