1use scirs2_core::ndarray::{Array, Array2, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
10use scirs2_core::numeric::{Float, FromPrimitive};
11use std::fmt::Debug;
12use std::marker::PhantomData;
13use std::mem;
14
15use crate::error::{NdimageError, NdimageResult};
16
/// How [`MemoryEfficientOp::execute`] should manage memory when applying an operation.
///
/// Only [`MemoryStrategy::InPlace`] currently selects a different execution path; every
/// other variant runs the operation out of place on a view of the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryStrategy {
    /// Copy-oriented strategy: the result is produced out of place into a new array.
    AlwaysCopy,
    /// The default strategy (see `MemoryConfig::default`): runs the operation out of place,
    /// reading the input through a borrowed view.
    PreferView,
    /// Mutates an owned copy of the input in place, provided `MemoryConfig::allow_inplace`
    /// is set and the operation reports it can run in place; otherwise falls back to
    /// out-of-place execution.
    InPlace,
    /// Buffer-reuse strategy; presumably meant to pair with [`BufferPool`], although
    /// `execute` does not use a pool yet and runs the operation out of place.
    ReuseBuffer,
}
29
30#[derive(Debug, Clone)]
32pub struct MemoryConfig {
33 pub strategy: MemoryStrategy,
35 pub memory_limit: Option<usize>,
37 pub allow_inplace: bool,
39 pub prefer_contiguous: bool,
41}
42
43impl Default for MemoryConfig {
44 fn default() -> Self {
45 Self {
46 strategy: MemoryStrategy::PreferView,
47 memory_limit: None,
48 allow_inplace: false,
49 prefer_contiguous: true,
50 }
51 }
52}
53
54pub struct BufferPool<T, D> {
56 buffers: Vec<Array<T, D>>,
57 max_buffers: usize,
58 _phantom: PhantomData<T>,
59}
60
61impl<T: Float + FromPrimitive + Debug + Clone, D: Dimension> BufferPool<T, D> {
62 pub fn new(maxbuffers: usize) -> Self {
63 Self {
64 buffers: Vec::new(),
65 max_buffers: maxbuffers,
66 _phantom: PhantomData,
67 }
68 }
69
70 pub fn get_buffer(&mut self, shape: D) -> Array<T, D> {
72 if let Some(pos) = self.buffers.iter().position(|b| b.raw_dim() == shape) {
74 self.buffers.swap_remove(pos)
75 } else {
76 Array::zeros(shape)
78 }
79 }
80
81 pub fn return_buffer(&mut self, buffer: Array<T, D>) {
83 if self.buffers.len() < self.max_buffers {
84 self.buffers.push(buffer);
85 }
86 }
87
88 pub fn clear(&mut self) {
90 self.buffers.clear();
91 }
92
93 pub fn len(&self) -> usize {
95 self.buffers.len()
96 }
97
98 pub fn is_empty(&self) -> bool {
100 self.buffers.is_empty()
101 }
102}
103
104pub trait InPlaceOp<T, D>
106where
107 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
108 D: Dimension + 'static,
109{
110 fn can_operate_inplace(&self) -> bool;
112
113 fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()>;
115
116 fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>>;
118}
119
120pub struct MemoryEfficientOp<T, D> {
122 config: MemoryConfig,
123 phantom: PhantomData<(T, D)>,
124}
125
126impl<
127 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
128 D: Dimension + 'static,
129 > MemoryEfficientOp<T, D>
130{
131 pub fn new(config: MemoryConfig) -> Self {
132 Self {
133 config,
134 phantom: PhantomData,
135 }
136 }
137
138 pub fn execute<Op, S>(&self, input: &ArrayBase<S, D>, op: Op) -> NdimageResult<Array<T, D>>
140 where
141 S: Data<Elem = T>,
142 Op: InPlaceOp<T, D>,
143 {
144 match self.config.strategy {
145 MemoryStrategy::AlwaysCopy => {
146 op.operate_out_of_place(&input.view())
148 }
149 MemoryStrategy::PreferView => {
150 op.operate_out_of_place(&input.view())
152 }
153 MemoryStrategy::InPlace => {
154 if self.config.allow_inplace && op.can_operate_inplace() {
155 let mut output = input.to_owned();
157 op.operate_inplace(&mut output.view_mut())?;
158 Ok(output)
159 } else {
160 op.operate_out_of_place(&input.view())
162 }
163 }
164 MemoryStrategy::ReuseBuffer => {
165 op.operate_out_of_place(&input.view())
167 }
168 }
169 }
170}
171
172#[allow(dead_code)]
174pub fn estimate_memory_usage<T, D>(shape: &[usize]) -> usize
175where
176 T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
177 D: Dimension + 'static,
178{
179 let elements: usize = shape.iter().product();
180 elements * std::mem::size_of::<T>()
181}
182
183#[allow(dead_code)]
185pub fn check_memory_limit<T, D>(shape: &[usize], limit: Option<usize>) -> NdimageResult<()>
186where
187 T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
188 D: Dimension + 'static,
189{
190 if let Some(max_bytes) = limit {
191 let required = estimate_memory_usage::<T, D>(shape);
192 if required > max_bytes {
193 return Err(NdimageError::MemoryError(format!(
194 "Operation would require {} bytes, exceeding limit of {} bytes",
195 required, max_bytes
196 )));
197 }
198 }
199 Ok(())
200}
201
202#[allow(dead_code)]
204pub fn create_output_array<T, D, S>(
205 input: &ArrayBase<S, D>,
206 config: &MemoryConfig,
207) -> NdimageResult<Array<T, D>>
208where
209 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
210 D: Dimension + 'static,
211 S: Data<Elem = T>,
212{
213 let shape = input.shape();
214 check_memory_limit::<T, D>(shape, config.memory_limit)?;
215
216 let output = if config.prefer_contiguous && !input.is_standard_layout() {
217 input.to_owned().as_standard_layout().to_owned()
219 } else {
220 input.to_owned()
222 };
223
224 Ok(output)
225}
226
/// Element-wise squaring operation (`x -> x * x`) that can run in place or out of place.
///
/// The operation is stateless, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SquareOp;
229
230impl<
231 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
232 D: Dimension + 'static,
233 > InPlaceOp<T, D> for SquareOp
234{
235 fn can_operate_inplace(&self) -> bool {
236 true
237 }
238
239 fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()> {
240 data.mapv_inplace(|x| x * x);
241 Ok(())
242 }
243
244 fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>> {
245 Ok(data.mapv(|x| x * x))
246 }
247}
248
/// Element-wise replacement of values above a threshold.
///
/// Every element strictly greater than `threshold` is replaced with `value`. All other
/// elements are kept, including NaN, because comparisons with NaN are false.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ThresholdOp<T> {
    /// Elements strictly greater than this are replaced.
    threshold: T,
    /// Replacement written for elements above `threshold`.
    value: T,
}
254
255impl<
256 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
257 > ThresholdOp<T>
258{
259 pub fn new(threshold: T, value: T) -> Self {
260 Self { threshold, value }
261 }
262}
263
264impl<
265 T: Float + FromPrimitive + Debug + Clone + std::ops::AddAssign + std::ops::DivAssign + 'static,
266 D: Dimension + 'static,
267 > InPlaceOp<T, D> for ThresholdOp<T>
268{
269 fn can_operate_inplace(&self) -> bool {
270 true
271 }
272
273 fn operate_inplace(&self, data: &mut ArrayViewMut<T, D>) -> NdimageResult<()> {
274 data.mapv_inplace(|x| if x > self.threshold { self.value } else { x });
275 Ok(())
276 }
277
278 fn operate_out_of_place(&self, data: &ArrayView<T, D>) -> NdimageResult<Array<T, D>> {
279 Ok(data.mapv(|x| if x > self.threshold { self.value } else { x }))
280 }
281}
282
283#[allow(dead_code)]
285pub fn slice_efficiently<'a, T, D, S>(
286 array: &'a ArrayBase<S, D>,
287 _slice_info: &[std::ops::Range<usize>],
288) -> ArrayView<'a, T, D>
289where
290 T: Float + std::ops::AddAssign + std::ops::DivAssign + 'static,
291 D: Dimension + 'static,
292 S: Data<Elem = T>,
293{
294 array.view()
296}
297
298#[allow(dead_code)]
300pub fn transpose_view<T, S>(array: &ArrayBase<S, scirs2_core::ndarray::Ix2>) -> Array2<T>
301where
302 T: Float + Copy,
303 S: Data<Elem = T>,
304{
305 array.t().to_owned()
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Ix2};

    #[test]
    fn test_buffer_pool() {
        let mut pool: BufferPool<f64, Ix2> = BufferPool::new(5);

        let buf1 = pool.get_buffer(Ix2(10, 10));
        assert_eq!(buf1.shape(), &[10, 10]);
        assert_eq!(pool.len(), 0);

        pool.return_buffer(buf1);
        assert_eq!(pool.len(), 1);

        let buf2 = pool.get_buffer(Ix2(10, 10));
        assert_eq!(buf2.shape(), &[10, 10]);
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_buffer_pool_respects_capacity() {
        let mut pool: BufferPool<f64, Ix2> = BufferPool::new(1);
        let first = pool.get_buffer(Ix2(2, 2));
        let second = pool.get_buffer(Ix2(2, 2));

        pool.return_buffer(first);
        pool.return_buffer(second);
        // Only one idle buffer may be retained; the second is dropped.
        assert_eq!(pool.len(), 1);

        pool.clear();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_memory_efficient_op() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let config = MemoryConfig {
            strategy: MemoryStrategy::InPlace,
            allow_inplace: true,
            ..Default::default()
        };

        let op_wrapper = MemoryEfficientOp::new(config);
        let result = op_wrapper
            .execute(&input, SquareOp)
            .expect("Operation failed");

        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[0, 1]], 4.0);
        assert_eq!(result[[1, 0]], 9.0);
        assert_eq!(result[[1, 1]], 16.0);
    }

    #[test]
    fn test_inplace_strategy_falls_back_when_disallowed() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let config = MemoryConfig {
            strategy: MemoryStrategy::InPlace,
            allow_inplace: false,
            ..Default::default()
        };

        let op_wrapper = MemoryEfficientOp::new(config);
        let result = op_wrapper
            .execute(&input, ThresholdOp::new(2.5, 0.0))
            .expect("Operation failed");

        assert_eq!(result, arr2(&[[1.0, 2.0], [0.0, 0.0]]));
        // The caller's input must never be modified.
        assert_eq!(input, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
    }

    #[test]
    fn test_memory_limit_check() {
        assert!(check_memory_limit::<f64, Ix2>(&[10, 10], Some(1000)).is_ok());
        assert!(check_memory_limit::<f64, Ix2>(&[1000, 1000], Some(1000)).is_err());
        assert!(check_memory_limit::<f64, Ix2>(&[1000, 1000], None).is_ok());
    }

    #[test]
    fn test_estimate_memory_usage() {
        assert_eq!(estimate_memory_usage::<f64, Ix2>(&[10, 10]), 800);
        assert_eq!(estimate_memory_usage::<f32, Ix2>(&[3, 0, 5]), 0);
    }

    #[test]
    fn test_create_output_array_is_standard_layout() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]).reversed_axes();
        assert!(!input.is_standard_layout());

        let output =
            create_output_array(&input, &MemoryConfig::default()).expect("allocation failed");

        assert!(output.is_standard_layout());
        assert_eq!(output, input);
    }

    #[test]
    fn test_transpose_view() {
        let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let transposed = transpose_view(&input);

        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed[[0, 1]], 4.0);
        assert_eq!(transposed[[2, 0]], 3.0);
    }

    #[test]
    fn test_slice_efficiently_without_ranges_is_full_view() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let view = slice_efficiently(&input, &[]);
        assert_eq!(view, input.view());
    }

    #[test]
    fn test_threshold_op() {
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let op = ThresholdOp::new(2.5, 0.0);

        let config = MemoryConfig::default();
        let op_wrapper = MemoryEfficientOp::new(config);
        let result = op_wrapper.execute(&input, op).expect("Operation failed");

        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[1, 0]], 0.0);
        assert_eq!(result[[1, 1]], 0.0);
    }
}