// sparse_ir/working_buffer.rs

1//! Working buffer for in-place operations
2//!
3//! Provides a reusable buffer that can be used for temporary storage during
4//! evaluate/fit operations, avoiding repeated allocations.
5
6use std::alloc::{self, Layout};
7use std::ptr::NonNull;
8
/// A reusable working buffer for temporary storage
///
/// This buffer manages raw memory that can be interpreted as different types
/// (f64 or Complex<f64>) depending on the operation. It automatically grows
/// when more space is needed.
///
/// # Safety
/// This struct manages raw memory and must be used carefully:
/// - The buffer is aligned for Complex<f64> (16 bytes)
/// - When casting to different types, ensure alignment requirements are met
pub struct WorkingBuffer {
    /// Raw pointer to the buffer. Invariant: dangling (well-aligned, never
    /// dereferenced) whenever `layout` is `None` / `capacity_bytes == 0`.
    ptr: NonNull<u8>,
    /// Capacity in bytes (0 when no allocation is live)
    capacity_bytes: usize,
    /// Layout of the live allocation, kept so `dealloc` can be called with
    /// the exact layout that was passed to `alloc`; `None` when unallocated.
    layout: Option<Layout>,
}
27
28impl WorkingBuffer {
29    /// Create a new empty working buffer
30    pub fn new() -> Self {
31        Self {
32            ptr: NonNull::dangling(),
33            capacity_bytes: 0,
34            layout: None,
35        }
36    }
37
38    /// Create a new working buffer with initial capacity (in bytes)
39    pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
40        if capacity_bytes == 0 {
41            return Self::new();
42        }
43
44        // Align to 16 bytes for Complex<f64> compatibility
45        let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");
46
47        let ptr = unsafe { alloc::alloc(layout) };
48        let ptr = NonNull::new(ptr).expect("Allocation failed");
49
50        Self {
51            ptr,
52            capacity_bytes,
53            layout: Some(layout),
54        }
55    }
56
57    /// Ensure the buffer has at least the specified capacity in bytes
58    ///
59    /// If the current capacity is insufficient, the buffer is reallocated.
60    /// Existing data is NOT preserved.
61    pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
62        if required_bytes <= self.capacity_bytes {
63            return;
64        }
65
66        // Deallocate old buffer if any
67        self.deallocate();
68
69        // Allocate new buffer with some extra room to avoid frequent reallocations
70        let new_capacity = required_bytes.max(required_bytes * 3 / 2);
71        let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");
72
73        let ptr = unsafe { alloc::alloc(layout) };
74        self.ptr = NonNull::new(ptr).expect("Allocation failed");
75        self.capacity_bytes = new_capacity;
76        self.layout = Some(layout);
77    }
78
79    /// Ensure the buffer can hold at least `count` elements of type T
80    pub fn ensure_capacity<T>(&mut self, count: usize) {
81        let required_bytes = count * std::mem::size_of::<T>();
82        self.ensure_capacity_bytes(required_bytes);
83    }
84
85    /// Get the buffer as a mutable slice of f64
86    ///
87    /// # Safety
88    /// Caller must ensure:
89    /// - The buffer has enough capacity for `count` f64 elements
90    /// - No other references to this buffer exist
91    #[allow(unsafe_op_in_unsafe_fn)]
92    pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
93        debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
94        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count) }
95    }
96
97    /// Get the buffer as a mutable slice of Complex<f64>
98    ///
99    /// # Safety
100    /// Caller must ensure:
101    /// - The buffer has enough capacity for `count` Complex<f64> elements
102    /// - No other references to this buffer exist
103    #[allow(unsafe_op_in_unsafe_fn)]
104    pub unsafe fn as_complex_slice_mut(
105        &mut self,
106        count: usize,
107    ) -> &mut [num_complex::Complex<f64>] {
108        debug_assert!(
109            count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
110        );
111        unsafe {
112            std::slice::from_raw_parts_mut(
113                self.ptr.as_ptr() as *mut num_complex::Complex<f64>,
114                count,
115            )
116        }
117    }
118
119    /// Get the raw pointer
120    pub fn as_ptr(&self) -> *mut u8 {
121        self.ptr.as_ptr()
122    }
123
124    /// Get current capacity in bytes
125    pub fn capacity_bytes(&self) -> usize {
126        self.capacity_bytes
127    }
128
129    /// Deallocate the buffer
130    fn deallocate(&mut self) {
131        if let Some(layout) = self.layout.take() {
132            unsafe {
133                alloc::dealloc(self.ptr.as_ptr(), layout);
134            }
135        }
136        self.ptr = NonNull::dangling();
137        self.capacity_bytes = 0;
138    }
139}
140
141impl Default for WorkingBuffer {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
impl Drop for WorkingBuffer {
    fn drop(&mut self) {
        // Release the live allocation, if any; `deallocate` is a no-op when
        // `layout` is `None`, so dropping a never-allocated buffer is safe.
        self.deallocate();
    }
}
152
// SAFETY: WorkingBuffer exclusively owns its allocation — the raw pointer is
// never handed out in an aliasing way by safe code — so moving the buffer to
// another thread is sound (Send). All mutation goes through `&mut self`, so
// concurrently sharing `&WorkingBuffer` (which only exposes the read-only
// `as_ptr`/`capacity_bytes`) cannot race (Sync).
unsafe impl Send for WorkingBuffer {}
unsafe impl Sync for WorkingBuffer {}
156
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// A fresh buffer holds no allocation.
    #[test]
    fn test_working_buffer_new() {
        assert_eq!(WorkingBuffer::new().capacity_bytes(), 0);
    }

    /// Pre-sized construction yields at least the requested capacity.
    #[test]
    fn test_working_buffer_with_capacity() {
        assert!(WorkingBuffer::with_capacity_bytes(1024).capacity_bytes() >= 1024);
    }

    /// Typed capacity requests translate into sufficient byte capacity.
    #[test]
    fn test_working_buffer_ensure_capacity() {
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(100);
        assert!(buf.capacity_bytes() >= 100 * std::mem::size_of::<f64>());
    }

    /// Data written through the f64 view is readable through a later view.
    #[test]
    fn test_working_buffer_as_f64_slice() {
        let n = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(n);

        unsafe {
            let view = buf.as_f64_slice_mut(n);
            assert_eq!(view.len(), n);

            // Fill with ascending values...
            for (i, v) in view.iter_mut().enumerate() {
                *v = i as f64;
            }

            // ...and verify they survive a second borrow of the buffer.
            for (i, v) in buf.as_f64_slice_mut(n).iter().enumerate() {
                assert_eq!(*v, i as f64);
            }
        }
    }

    /// Same round-trip as above, through the Complex<f64> view.
    #[test]
    fn test_working_buffer_as_complex_slice() {
        let n = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<Complex<f64>>(n);

        unsafe {
            let view = buf.as_complex_slice_mut(n);
            assert_eq!(view.len(), n);

            for (i, z) in view.iter_mut().enumerate() {
                *z = Complex::new(i as f64, (i + 1) as f64);
            }

            for (i, z) in buf.as_complex_slice_mut(n).iter().enumerate() {
                assert_eq!(*z, Complex::new(i as f64, (i + 1) as f64));
            }
        }
    }

    /// Growing past the current capacity actually reallocates larger.
    #[test]
    fn test_working_buffer_reallocation() {
        let mut buf = WorkingBuffer::with_capacity_bytes(100);
        let before = buf.capacity_bytes();

        buf.ensure_capacity_bytes(1000);
        assert!(buf.capacity_bytes() >= 1000);
        assert!(buf.capacity_bytes() > before);
    }
}