Skip to main content

scirs2_ndimage/
memory_management.rs

//! Memory management utilities for ndimage operations
//!
//! This module provides utilities for efficient memory management, including:
//! - Control over views versus copies
//! - In-place operation options
//! - Memory footprint optimization
//! - Buffer reuse strategies
8
9use scirs2_core::ndarray::{Array, Array2, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
10use scirs2_core::numeric::{Float, FromPrimitive};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13use std::mem;
14
15use crate::error::{NdimageError, NdimageResult};
16
/// Selects how an operation should obtain the memory for its result.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryStrategy {
    /// Allocate a fresh array for every result (safest, highest memory use).
    AlwaysCopy,
    /// Work through views where that is possible and copy only when required.
    PreferView,
    /// Overwrite existing data instead of allocating, when the operation allows it.
    InPlace,
    /// Write the result into a buffer that was allocated ahead of time.
    ReuseBuffer,
}
29
30/// Configuration for memory-efficient operations
31#[derive(Debug, Clone)]
32pub struct MemoryConfig {
33    /// Memory allocation strategy
34    pub strategy: MemoryStrategy,
35    /// Maximum memory to use (in bytes)
36    pub memory_limit: Option<usize>,
37    /// Whether to allow in-place operations
38    pub allow_inplace: bool,
39    /// Whether to prefer contiguous memory layout
40    pub prefer_contiguous: bool,
41}
42
43impl Default for MemoryConfig {
44    fn default() -> Self {
45        Self {
46            strategy: MemoryStrategy::PreferView,
47            memory_limit: None,
48            allow_inplace: false,
49            prefer_contiguous: true,
50        }
51    }
52}
53
54/// Buffer pool for reusing allocated arrays
55pub struct BufferPool<T, D> {
56    buffers: Vec<Array<T, D>>,
57    max_buffers: usize,
58    _phantom: PhantomData<T>,
59}
60
61impl<T: Float + FromPrimitive + Debug + Clone, D: Dimension> BufferPool<T, D> {
62    pub fn new(maxbuffers: usize) -> Self {
63        Self {
64            buffers: Vec::new(),
65            max_buffers: maxbuffers,
66            _phantom: PhantomData,
67        }
68    }
69
70    /// Get a buffer from the pool or allocate a new one
71    pub fn get_buffer(&mut self, shape: D) -> Array<T, D> {
72        // Try to find a buffer with matching shape
73        if let Some(pos) = self.buffers.iter().position(|b| b.raw_dim() == shape) {
74            self.buffers.swap_remove(pos)
75        } else {
76            // Allocate new buffer
77            Array::zeros(shape)
78        }
79    }
80
81    /// Return a buffer to the pool
82    pub fn return_buffer(&mut self, buffer: Array<T, D>) {
83        if self.buffers.len() < self.max_buffers {
84            self.buffers.push(buffer);
85        }
86    }
87
88    /// Clear all buffers from the pool
89    pub fn clear(&mut self) {
90        self.buffers.clear();
91    }
92
93    /// Get the number of buffers currently in the pool
94    pub fn len(&self) -> usize {
95        self.buffers.len()
96    }
97
98    /// Check if the pool is empty
99    pub fn is_empty(&self) -> bool {
100        self.buffers.is_empty()
101    }
102}
103
104/// Trait for operations that can be performed in-place
105pub trait InPlaceOp<T, D>
106where
107    T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
108    D: Dimension + 'static,
109{
110    /// Check if this operation can be performed in-place
111    fn can_operate_inplace(&self) -> bool;
112
113    /// Perform the operation in-place
114    fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()>;
115
116    /// Perform the operation out-of-place
117    fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>>;
118}
119
120/// Memory-efficient wrapper for array operations
121pub struct MemoryEfficientOp<T, D> {
122    config: MemoryConfig,
123    phantom: PhantomData<(T, D)>,
124}
125
126impl<
127        T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
128        D: Dimension + 'static,
129    > MemoryEfficientOp<T, D>
130{
131    pub fn new(config: MemoryConfig) -> Self {
132        Self {
133            config,
134            phantom: PhantomData,
135        }
136    }
137
138    /// Execute an operation with memory efficiency considerations
139    pub fn execute<Op, S>(&self, input: &ArrayBase<S, D>, op: Op) -> NdimageResult<Array<T, D>>
140    where
141        S: Data<Elem = T>,
142        Op: InPlaceOp<T, D>,
143    {
144        match self.config.strategy {
145            MemoryStrategy::AlwaysCopy => {
146                // Always create a new array
147                op.operate_out_of_place(&input.view())
148            }
149            MemoryStrategy::PreferView => {
150                // Use view when possible
151                op.operate_out_of_place(&input.view())
152            }
153            MemoryStrategy::InPlace => {
154                if self.config.allow_inplace && op.can_operate_inplace() {
155                    // Try to operate in-place if we own the data
156                    let mut output = input.to_owned();
157                    op.operate_inplace(&mut output.view_mut())?;
158                    Ok(output)
159                } else {
160                    // Fall back to out-of-place
161                    op.operate_out_of_place(&input.view())
162                }
163            }
164            MemoryStrategy::ReuseBuffer => {
165                // This would require a buffer pool passed in
166                op.operate_out_of_place(&input.view())
167            }
168        }
169    }
170}
171
172/// Estimate memory usage for an operation
173#[allow(dead_code)]
174pub fn estimate_memory_usage<T, D>(shape: &[usize]) -> usize
175where
176    T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
177    D: Dimension + 'static,
178{
179    let elements: usize = shape.iter().product();
180    elements * std::mem::size_of::<T>()
181}
182
183/// Check if an operation would exceed memory limit
184#[allow(dead_code)]
185pub fn check_memory_limit<T, D>(shape: &[usize], limit: Option<usize>) -> NdimageResult<()>
186where
187    T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
188    D: Dimension + 'static,
189{
190    if let Some(max_bytes) = limit {
191        let required = estimate_memory_usage::<T, D>(shape);
192        if required > max_bytes {
193            return Err(NdimageError::MemoryError(format!(
194                "Operation would require {} bytes, exceeding limit of {} bytes",
195                required, max_bytes
196            )));
197        }
198    }
199    Ok(())
200}
201
202/// Create a memory-efficient view or copy based on configuration
203#[allow(dead_code)]
204pub fn create_output_array<T, D, S>(
205    input: &ArrayBase<S, D>,
206    config: &MemoryConfig,
207) -> NdimageResult<Array<T, D>>
208where
209    T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
210    D: Dimension + 'static,
211    S: Data<Elem = T>,
212{
213    let shape = input.shape();
214    check_memory_limit::<T, D>(shape, config.memory_limit)?;
215
216    let output = if config.prefer_contiguous && !input.is_standard_layout() {
217        // Create contiguous copy
218        input.to_owned().as_standard_layout().to_owned()
219    } else {
220        // Create regular copy
221        input.to_owned()
222    };
223
224    Ok(output)
225}
226
227/// Example in-place operation: element-wise square
228pub struct SquareOp;
229
230impl<
231        T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
232        D: Dimension + 'static,
233    > InPlaceOp<T, D> for SquareOp
234{
235    fn can_operate_inplace(&self) -> bool {
236        true
237    }
238
239    fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()> {
240        data.mapv_inplace(|x| x * x);
241        Ok(())
242    }
243
244    fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>> {
245        Ok(data.mapv(|x| x * x))
246    }
247}
248
249/// Example in-place operation: threshold
250pub struct ThresholdOp<T> {
251    threshold: T,
252    value: T,
253}
254
255impl<
256        T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
257    > ThresholdOp<T>
258{
259    pub fn new(threshold: T, value: T) -> Self {
260        Self { threshold, value }
261    }
262}
263
264impl<
265        T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
266        D: Dimension + 'static,
267    > InPlaceOp<T, D> for ThresholdOp<T>
268{
269    fn can_operate_inplace(&self) -> bool {
270        true
271    }
272
273    fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()> {
274        data.mapv_inplace(|x| if x > self.threshold { self.value } else { x });
275        Ok(())
276    }
277
278    fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>> {
279        Ok(data.mapv(|x| if x > self.threshold { self.value } else { x }))
280    }
281}
282
283/// Memory-efficient array slicing that avoids copies when possible
284#[allow(dead_code)]
285pub fn slice_efficiently<'a, T, D, S>(
286    array: &'a ArrayBase<S, D>,
287    _slice_info: &[std::ops::Range<usize>],
288) -> ArrayView<'a, T, D>
289where
290    T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
291    D: Dimension + 'static,
292    S: Data<Elem = T>,
293{
294    // This is a simplified version - in practice would use ndarray's slicing
295    array.view()
296}
297
298/// Zero-copy transpose for 2D arrays
299#[allow(dead_code)]
300pub fn transpose_view<T, S>(array: &ArrayBase<S, scirs2_core::ndarray::Ix2>) -> Array2<T>
301where
302    T: Float + Copy,
303    S: Data<Elem = T>,
304{
305    array.t().to_owned()
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Ix2};

    /// A buffer handed back to the pool is served again for a matching shape.
    #[test]
    fn test_buffer_pool() {
        let mut pool = BufferPool::<f64, Ix2>::new(5);

        // Nothing is pooled yet, so this request allocates.
        let first = pool.get_buffer(Ix2(10, 10));
        assert_eq!(first.shape(), &[10, 10]);
        assert_eq!(pool.len(), 0);

        // Returning the buffer parks it in the pool.
        pool.return_buffer(first);
        assert_eq!(pool.len(), 1);

        // A second request for the same shape takes the parked buffer.
        let second = pool.get_buffer(Ix2(10, 10));
        assert_eq!(second.shape(), &[10, 10]);
        assert_eq!(pool.len(), 0);
    }

    /// The in-place strategy squares every element of the input.
    #[test]
    fn test_memory_efficient_op() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let config = MemoryConfig {
            strategy: MemoryStrategy::InPlace,
            allow_inplace: true,
            ..Default::default()
        };

        let runner = MemoryEfficientOp::new(config);
        let squared = runner
            .execute(&input, SquareOp)
            .expect("Operation failed");

        assert_eq!(squared, arr2(&[[1.0, 4.0], [9.0, 16.0]]));
    }

    /// The limit check accepts arrays that fit and rejects ones that do not.
    #[test]
    fn test_memory_limit_check() {
        // 10 * 10 * 8 bytes = 800 bytes, within the 1000-byte limit.
        let small = check_memory_limit::<f64, Ix2>(&[10, 10], Some(1000));
        assert!(small.is_ok());

        // 1000 * 1000 * 8 bytes is far beyond the 1000-byte limit.
        let large = check_memory_limit::<f64, Ix2>(&[1000, 1000], Some(1000));
        assert!(large.is_err());
    }

    /// Elements above the threshold are replaced; the rest are untouched.
    #[test]
    fn test_threshold_op() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let clip_above = ThresholdOp::new(2.5, 0.0);

        let runner = MemoryEfficientOp::new(MemoryConfig::default());
        let clipped = runner
            .execute(&input, clip_above)
            .expect("Operation failed");

        // The second row (3.0 and 4.0) exceeds 2.5 and becomes 0.0.
        assert_eq!(clipped, arr2(&[[1.0, 2.0], [0.0, 0.0]]));
    }
}