scirs2_optimize/unconstrained/
memory_efficient.rs

1//! Memory-efficient algorithms for large-scale optimization problems
2//!
3//! This module provides optimization algorithms designed to handle very large problems
4//! with limited memory by using chunked processing, streaming algorithms, and
5//! memory pool management.
6
7use crate::error::OptimizeError;
8use crate::unconstrained::line_search::backtracking_line_search;
9use crate::unconstrained::result::OptimizeResult;
10use crate::unconstrained::utils::check_convergence;
11use crate::unconstrained::Options;
12use scirs2_core::ndarray::{Array1, ArrayView1};
13use std::collections::VecDeque;
14
15/// Memory optimization options for large-scale problems
16#[derive(Debug, Clone)]
17pub struct MemoryOptions {
18    /// Base optimization options
19    pub base_options: Options,
20    /// Maximum memory usage in bytes (0 means unlimited)
21    pub max_memory_bytes: usize,
22    /// Chunk size for processing large vectors
23    pub chunk_size: usize,
24    /// Maximum history size for L-BFGS-style methods
25    pub max_history: usize,
26    /// Whether to use memory pooling
27    pub use_memory_pool: bool,
28    /// Whether to use out-of-core storage for very large problems
29    pub use_out_of_core: bool,
30    /// Temporary directory for out-of-core storage
31    pub temp_dir: Option<std::path::PathBuf>,
32}
33
34impl Default for MemoryOptions {
35    fn default() -> Self {
36        Self {
37            base_options: Options::default(),
38            max_memory_bytes: 0, // Unlimited by default
39            chunk_size: 1024,    // Process 1024 elements at a time
40            max_history: 10,     // Keep last 10 iterations
41            use_memory_pool: true,
42            use_out_of_core: false,
43            temp_dir: None,
44        }
45    }
46}
47
48/// Memory pool for reusing arrays to reduce allocations
49struct MemoryPool {
50    array_pool: VecDeque<Array1<f64>>,
51    max_pool_size: usize,
52}
53
54impl MemoryPool {
55    fn new(max_size: usize) -> Self {
56        Self {
57            array_pool: VecDeque::new(),
58            max_pool_size: max_size,
59        }
60    }
61
62    fn get_array(&mut self, size: usize) -> Array1<f64> {
63        // Try to reuse an existing array of the right size
64        for i in 0..self.array_pool.len() {
65            if self.array_pool[i].len() == size {
66                return self.array_pool.remove(i).unwrap();
67            }
68        }
69        // If no suitable array found, create a new one
70        Array1::zeros(size)
71    }
72
73    fn return_array(&mut self, mut array: Array1<f64>) {
74        if self.array_pool.len() < self.max_pool_size {
75            // Zero out the array for reuse
76            array.fill(0.0);
77            self.array_pool.push_back(array);
78        }
79        // If pool is full, just drop the array
80    }
81}
82
83/// Streaming gradient computation for very large problems
84struct StreamingGradient {
85    chunk_size: usize,
86    eps: f64,
87}
88
89impl StreamingGradient {
90    fn new(chunk_size: usize, eps: f64) -> Self {
91        Self { chunk_size, eps }
92    }
93
94    /// Compute gradient using chunked finite differences
95    fn compute<F, S>(&self, fun: &mut F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
96    where
97        F: FnMut(&ArrayView1<f64>) -> S,
98        S: Into<f64>,
99    {
100        let n = x.len();
101        let mut grad = Array1::zeros(n);
102        let f0 = fun(x).into();
103
104        let mut x_pert = x.to_owned();
105
106        // Process gradient computation in chunks to limit memory usage
107        for chunk_start in (0..n).step_by(self.chunk_size) {
108            let chunk_end = std::cmp::min(chunk_start + self.chunk_size, n);
109
110            for i in chunk_start..chunk_end {
111                let h = self.eps * (1.0 + x[i].abs());
112                x_pert[i] = x[i] + h;
113
114                let f_plus = fun(&x_pert.view()).into();
115
116                if !f_plus.is_finite() {
117                    return Err(OptimizeError::ComputationError(
118                        "Function returned non-finite value during gradient computation"
119                            .to_string(),
120                    ));
121                }
122
123                grad[i] = (f_plus - f0) / h;
124                x_pert[i] = x[i]; // Reset
125            }
126
127            // Optional: yield control to prevent blocking for too long
128            // In a real implementation, this could check for cancellation
129        }
130
131        Ok(grad)
132    }
133}
134
135/// Memory-efficient L-BFGS implementation with bounded history
136#[allow(dead_code)]
137pub fn minimize_memory_efficient_lbfgs<F, S>(
138    mut fun: F,
139    x0: Array1<f64>,
140    options: &MemoryOptions,
141) -> Result<OptimizeResult<S>, OptimizeError>
142where
143    F: FnMut(&ArrayView1<f64>) -> S + Clone,
144    S: Into<f64> + Clone,
145{
146    let n = x0.len();
147    let base_opts = &options.base_options;
148
149    // Initialize memory pool if enabled
150    let mut memory_pool = if options.use_memory_pool {
151        Some(MemoryPool::new(options.max_history * 2))
152    } else {
153        None
154    };
155
156    // Estimate memory usage
157    let estimated_memory = estimate_memory_usage(n, options.max_history);
158    if options.max_memory_bytes > 0 && estimated_memory > options.max_memory_bytes {
159        return Err(OptimizeError::ValueError(format!(
160            "Estimated memory usage ({} bytes) exceeds limit ({} bytes). Consider reducing max_history or chunk_size.",
161            estimated_memory, options.max_memory_bytes
162        )));
163    }
164
165    // Initialize variables
166    let mut x = x0.to_owned();
167    let bounds = base_opts.bounds.as_ref();
168
169    // Ensure initial point is within bounds
170    if let Some(bounds) = bounds {
171        bounds.project(x.as_slice_mut().unwrap());
172    }
173
174    let mut f = fun(&x.view()).into();
175
176    // Initialize streaming gradient computer
177    let streaming_grad = StreamingGradient::new(options.chunk_size, base_opts.eps);
178
179    // Calculate initial gradient
180    let mut g = streaming_grad.compute(&mut fun, &x.view())?;
181
182    // L-BFGS history with bounded size
183    let mut s_history: VecDeque<Array1<f64>> = VecDeque::new();
184    let mut y_history: VecDeque<Array1<f64>> = VecDeque::new();
185
186    // Initialize counters
187    let mut iter = 0;
188    let mut nfev = 1 + n.div_ceil(options.chunk_size); // Initial function evaluations
189
190    // Main loop
191    while iter < base_opts.max_iter {
192        // Check convergence on gradient
193        if g.mapv(|gi| gi.abs()).sum() < base_opts.gtol {
194            break;
195        }
196
197        // Compute search direction using L-BFGS two-loop recursion
198        let mut p = if s_history.is_empty() {
199            // Use steepest descent if no history
200            get_array_from_pool(&mut memory_pool, n, |_| -&g)
201        } else {
202            compute_lbfgs_direction_memory_efficient(&g, &s_history, &y_history, &mut memory_pool)?
203        };
204
205        // Project search direction for bounded optimization
206        if let Some(bounds) = bounds {
207            for i in 0..n {
208                let mut can_decrease = true;
209                let mut can_increase = true;
210
211                // Check if at boundary
212                if let Some(lb) = bounds.lower[i] {
213                    if x[i] <= lb + base_opts.eps {
214                        can_decrease = false;
215                    }
216                }
217                if let Some(ub) = bounds.upper[i] {
218                    if x[i] >= ub - base_opts.eps {
219                        can_increase = false;
220                    }
221                }
222
223                // Project gradient component
224                if (g[i] > 0.0 && !can_decrease) || (g[i] < 0.0 && !can_increase) {
225                    p[i] = 0.0;
226                }
227            }
228
229            // If no movement is possible, we're at a constrained optimum
230            if p.mapv(|pi| pi.abs()).sum() < 1e-10 {
231                return_array_to_pool(&mut memory_pool, p);
232                break;
233            }
234        }
235
236        // Line search
237        let alpha_init = 1.0;
238        let (alpha, f_new) = backtracking_line_search(
239            &mut fun,
240            &x.view(),
241            f,
242            &p.view(),
243            &g.view(),
244            alpha_init,
245            0.0001,
246            0.5,
247            bounds,
248        );
249
250        nfev += 1;
251
252        // Update position
253        let s = get_array_from_pool(&mut memory_pool, n, |_| alpha * &p);
254        let x_new = &x + &s;
255
256        // Check step size convergence
257        if array_norm_chunked(&s, options.chunk_size) < base_opts.xtol {
258            return_array_to_pool(&mut memory_pool, s);
259            return_array_to_pool(&mut memory_pool, p);
260            x = x_new;
261            break;
262        }
263
264        // Calculate new gradient using streaming approach
265        let g_new = streaming_grad.compute(&mut fun, &x_new.view())?;
266        nfev += n.div_ceil(options.chunk_size);
267
268        // Gradient difference
269        let y = get_array_from_pool(&mut memory_pool, n, |_| &g_new - &g);
270
271        // Check convergence on function value
272        if check_convergence(
273            f - f_new,
274            0.0,
275            g_new.mapv(|gi| gi.abs()).sum(),
276            base_opts.ftol,
277            0.0,
278            base_opts.gtol,
279        ) {
280            return_array_to_pool(&mut memory_pool, s);
281            return_array_to_pool(&mut memory_pool, y);
282            return_array_to_pool(&mut memory_pool, p);
283            x = x_new;
284            g = g_new;
285            break;
286        }
287
288        // Update L-BFGS history with memory management
289        let s_dot_y = chunked_dot_product(&s, &y, options.chunk_size);
290        if s_dot_y > 1e-10 {
291            // Add to history
292            s_history.push_back(s);
293            y_history.push_back(y);
294
295            // Remove oldest entries if history is too long
296            while s_history.len() > options.max_history {
297                if let Some(old_s) = s_history.pop_front() {
298                    return_array_to_pool(&mut memory_pool, old_s);
299                }
300                if let Some(old_y) = y_history.pop_front() {
301                    return_array_to_pool(&mut memory_pool, old_y);
302                }
303            }
304        } else {
305            // Don't update history, just return arrays to pool
306            return_array_to_pool(&mut memory_pool, s);
307            return_array_to_pool(&mut memory_pool, y);
308        }
309
310        return_array_to_pool(&mut memory_pool, p);
311
312        // Update variables for next iteration
313        x = x_new;
314        f = f_new;
315        g = g_new;
316        iter += 1;
317    }
318
319    // Clean up remaining arrays in history
320    while let Some(s) = s_history.pop_front() {
321        return_array_to_pool(&mut memory_pool, s);
322    }
323    while let Some(y) = y_history.pop_front() {
324        return_array_to_pool(&mut memory_pool, y);
325    }
326
327    // Final check for bounds
328    if let Some(bounds) = bounds {
329        bounds.project(x.as_slice_mut().unwrap());
330    }
331
332    // Use original function for final value
333    let final_fun = fun(&x.view());
334
335    Ok(OptimizeResult {
336        x,
337        fun: final_fun,
338        nit: iter,
339        func_evals: nfev,
340        nfev,
341        success: iter < base_opts.max_iter,
342        message: if iter < base_opts.max_iter {
343            "Memory-efficient optimization terminated successfully.".to_string()
344        } else {
345            "Maximum iterations reached.".to_string()
346        },
347        jacobian: Some(g),
348        hessian: None,
349    })
350}
351
352/// Memory-efficient L-BFGS direction computation
353#[allow(dead_code)]
354fn compute_lbfgs_direction_memory_efficient(
355    g: &Array1<f64>,
356    s_history: &VecDeque<Array1<f64>>,
357    y_history: &VecDeque<Array1<f64>>,
358    memory_pool: &mut Option<MemoryPool>,
359) -> Result<Array1<f64>, OptimizeError> {
360    let m = s_history.len();
361    if m == 0 {
362        return Ok(-g);
363    }
364
365    let n = g.len();
366    let mut q = get_array_from_pool(memory_pool, n, |_| g.clone());
367    let mut alpha = vec![0.0; m];
368
369    // First loop: backward through _history
370    for i in (0..m).rev() {
371        let rho_i = 1.0 / y_history[i].dot(&s_history[i]);
372        alpha[i] = rho_i * s_history[i].dot(&q);
373        let temp = get_array_from_pool(memory_pool, n, |_| &q - alpha[i] * &y_history[i]);
374        return_array_to_pool(memory_pool, q);
375        q = temp;
376    }
377
378    // Apply initial Hessian approximation (simple scaling)
379    let gamma = if m > 0 {
380        s_history[m - 1].dot(&y_history[m - 1]) / y_history[m - 1].dot(&y_history[m - 1])
381    } else {
382        1.0
383    };
384
385    let mut r = get_array_from_pool(memory_pool, n, |_| gamma * &q);
386    return_array_to_pool(memory_pool, q);
387
388    // Second loop: forward through _history
389    for i in 0..m {
390        let rho_i = 1.0 / y_history[i].dot(&s_history[i]);
391        let beta = rho_i * y_history[i].dot(&r);
392        let temp = get_array_from_pool(memory_pool, n, |_| &r + (alpha[i] - beta) * &s_history[i]);
393        return_array_to_pool(memory_pool, r);
394        r = temp;
395    }
396
397    // Return -r as the search direction
398    let result = get_array_from_pool(memory_pool, n, |_| -&r);
399    return_array_to_pool(memory_pool, r);
400    Ok(result)
401}
402
403/// Get array from memory pool or create new one
404#[allow(dead_code)]
405fn get_array_from_pool<F>(
406    memory_pool: &mut Option<MemoryPool>,
407    size: usize,
408    init_fn: F,
409) -> Array1<f64>
410where
411    F: FnOnce(usize) -> Array1<f64>,
412{
413    match memory_pool {
414        Some(pool) => {
415            let mut array = pool.get_array(size);
416            if array.len() != size {
417                array = Array1::zeros(size);
418            }
419            let result = init_fn(size);
420            pool.return_array(array);
421            result
422        }
423        None => init_fn(size),
424    }
425}
426
427/// Return array to memory pool
428#[allow(dead_code)]
429fn return_array_to_pool(_memory_pool: &mut Option<MemoryPool>, array: Array1<f64>) {
430    if let Some(pool) = _memory_pool {
431        pool.return_array(array);
432    }
433    // If no pool, array will be dropped normally
434}
435
436/// Compute dot product in chunks to reduce memory usage
437#[allow(dead_code)]
438fn chunked_dot_product(a: &Array1<f64>, b: &Array1<f64>, chunk_size: usize) -> f64 {
439    let n = a.len();
440    let mut result = 0.0;
441
442    for chunk_start in (0..n).step_by(chunk_size) {
443        let chunk_end = std::cmp::min(chunk_start + chunk_size, n);
444        let a_chunk = a.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
445        let b_chunk = b.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
446        result += a_chunk.dot(&b_chunk);
447    }
448
449    result
450}
451
452/// Compute array norm in chunks to reduce memory usage
453#[allow(dead_code)]
454fn array_norm_chunked(array: &Array1<f64>, chunk_size: usize) -> f64 {
455    let n = array.len();
456    let mut sum_sq: f64 = 0.0;
457
458    for chunk_start in (0..n).step_by(chunk_size) {
459        let chunk_end = std::cmp::min(chunk_start + chunk_size, n);
460        let chunk = array.slice(scirs2_core::ndarray::s![chunk_start..chunk_end]);
461        sum_sq += chunk.mapv(|x| x.powi(2)).sum();
462    }
463
464    sum_sq.sqrt()
465}
466
/// Estimates, in bytes, the memory needed for a problem with `n` variables
/// and an L-BFGS history of `maxhistory` pairs.
///
/// The estimate counts `f64` vectors of length `n`:
/// * 2 for the current point and gradient,
/// * `2 * maxhistory` for the stored `s` and `y` vectors,
/// * 4 for temporaries.
///
/// All arithmetic saturates, so inputs too large for `usize` yield
/// `usize::MAX` instead of overflowing (which would panic in debug builds or
/// wrap to a misleadingly small value in release builds).
#[allow(dead_code)]
fn estimate_memory_usage(n: usize, maxhistory: usize) -> usize {
    // Size of f64 in bytes
    const F64_SIZE: usize = std::mem::size_of::<f64>();
    // Current point + gradient (2) and temporary arrays (4).
    const FIXED_VECTORS: usize = 2 + 4;

    let vector_count = maxhistory.saturating_mul(2).saturating_add(FIXED_VECTORS);
    vector_count.saturating_mul(n).saturating_mul(F64_SIZE)
}
484
485/// Create a memory-efficient optimizer with automatic parameter selection
486#[allow(dead_code)]
487pub fn create_memory_efficient_optimizer(
488    problem_size: usize,
489    available_memory_mb: usize,
490) -> MemoryOptions {
491    let available_bytes = available_memory_mb * 1024 * 1024;
492
493    // Estimate parameters based on available memory
494    let max_history = std::cmp::min(
495        20,
496        available_bytes / (2 * problem_size * std::mem::size_of::<f64>() * 4),
497    )
498    .max(1);
499
500    let chunk_size = std::cmp::min(
501        problem_size,
502        std::cmp::max(64, available_bytes / (8 * std::mem::size_of::<f64>())),
503    );
504
505    MemoryOptions {
506        base_options: Options::default(),
507        max_memory_bytes: available_bytes,
508        chunk_size,
509        max_history,
510        use_memory_pool: true,
511        use_out_of_core: available_memory_mb < 100, // Use out-of-core for very limited memory
512        temp_dir: None,
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The solver should drive a separable quadratic to the origin.
    #[test]
    fn test_memory_efficient_lbfgs_quadratic() {
        // f(x) = sum(x_i^2)
        let sum_of_squares =
            |x: &ArrayView1<f64>| -> f64 { x.mapv(|xi| xi.powi(2)).sum() };

        let dim: usize = 100;
        let start = Array1::ones(dim);
        // Small chunks and a short history to exercise the bounded paths.
        let options = MemoryOptions {
            chunk_size: 32,
            max_history: 5,
            ..MemoryOptions::default()
        };

        let result = minimize_memory_efficient_lbfgs(sum_of_squares, start, &options).unwrap();

        assert!(result.success);
        // Spot-check the leading coordinates against the known minimizer.
        for &xi in result.x.iter().take(dim.min(10)) {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    /// Chunked reductions must agree with their whole-array counterparts.
    #[test]
    fn test_chunked_operations() {
        let a = Array1::from_vec((0..100).map(|i| i as f64).collect());
        let b = Array1::from_vec((0..100).map(|i| (i * 2) as f64).collect());

        let expected_dot = a.dot(&b);
        let actual_dot = chunked_dot_product(&a, &b, 10);
        assert_abs_diff_eq!(actual_dot, expected_dot, epsilon = 1e-10);

        let expected_norm = a.mapv(|x| x.powi(2)).sum().sqrt();
        let actual_norm = array_norm_chunked(&a, 10);
        assert_abs_diff_eq!(actual_norm, expected_norm, epsilon = 1e-10);
    }

    /// Returned arrays stay in the pool and are handed out again.
    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new(3);

        // First round: nothing pooled yet.
        let first = pool.get_array(10);
        let second = pool.get_array(10);
        pool.return_array(first);
        pool.return_array(second);

        // Second round: both requests are served from the pool.
        let third = pool.get_array(10);
        let fourth = pool.get_array(10);
        pool.return_array(third);
        pool.return_array(fourth);

        assert_eq!(pool.array_pool.len(), 2);
    }

    /// The estimate for a small problem is positive and well below 1 MB.
    #[test]
    fn test_memory_estimation() {
        let estimated = estimate_memory_usage(1000, 10);

        assert!(estimated > 0);
        assert!(estimated < 1_000_000);
    }

    /// Automatically selected parameters are all strictly positive.
    #[test]
    fn test_auto_parameter_selection() {
        // 10 000 variables with a 64 MB budget.
        let options = create_memory_efficient_optimizer(10000, 64);

        assert!(options.chunk_size > 0);
        assert!(options.max_history > 0);
        assert!(options.max_memory_bytes > 0);
    }
}