// sparse_ir/working_buffer.rs
1//! Working buffer for in-place operations
2//!
3//! Provides a reusable buffer that can be used for temporary storage during
4//! evaluate/fit operations, avoiding repeated allocations.
5
6use std::alloc::{self, Layout};
7use std::ptr::NonNull;
8
/// A reusable working buffer for temporary storage
///
/// This buffer manages raw memory that can be interpreted as different types
/// (f64 or Complex<f64>) depending on the operation. It automatically grows
/// when more space is needed.
///
/// Invariant: `ptr` is a live allocation of `capacity_bytes` bytes described
/// by `layout` when `layout` is `Some`; otherwise `ptr` is dangling and
/// `capacity_bytes` is 0.
///
/// # Safety
/// This struct manages raw memory and must be used carefully:
/// - The buffer is aligned for Complex<f64> (16 bytes)
/// - When casting to different types, ensure alignment requirements are met
pub struct WorkingBuffer {
    /// Raw pointer to the allocation; dangling (never dereferenced) when
    /// `layout` is `None`
    ptr: NonNull<u8>,
    /// Capacity in bytes (0 when nothing is allocated)
    capacity_bytes: usize,
    /// Layout of the current allocation, kept so deallocation can hand the
    /// exact same layout back to the allocator; `None` when unallocated
    layout: Option<Layout>,
}
27
28impl WorkingBuffer {
29    /// Create a new empty working buffer
30    pub fn new() -> Self {
31        Self {
32            ptr: NonNull::dangling(),
33            capacity_bytes: 0,
34            layout: None,
35        }
36    }
37
38    /// Create a new working buffer with initial capacity (in bytes)
39    pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
40        if capacity_bytes == 0 {
41            return Self::new();
42        }
43
44        // Align to 16 bytes for Complex<f64> compatibility
45        let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");
46
47        let ptr = unsafe { alloc::alloc(layout) };
48        let ptr = NonNull::new(ptr).expect("Allocation failed");
49
50        Self {
51            ptr,
52            capacity_bytes,
53            layout: Some(layout),
54        }
55    }
56
57    /// Ensure the buffer has at least the specified capacity in bytes
58    ///
59    /// If the current capacity is insufficient, the buffer is reallocated.
60    /// Existing data is NOT preserved.
61    pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
62        if required_bytes <= self.capacity_bytes {
63            return;
64        }
65
66        // Deallocate old buffer if any
67        self.deallocate();
68
69        // Allocate new buffer with some extra room to avoid frequent reallocations
70        let new_capacity = required_bytes.max(required_bytes * 3 / 2);
71        let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");
72
73        let ptr = unsafe { alloc::alloc(layout) };
74        self.ptr = NonNull::new(ptr).expect("Allocation failed");
75        self.capacity_bytes = new_capacity;
76        self.layout = Some(layout);
77    }
78
79    /// Ensure the buffer can hold at least `count` elements of type T
80    pub fn ensure_capacity<T>(&mut self, count: usize) {
81        let required_bytes = count * std::mem::size_of::<T>();
82        self.ensure_capacity_bytes(required_bytes);
83    }
84
85    /// Get the buffer as a mutable slice of f64
86    ///
87    /// # Safety
88    /// Caller must ensure:
89    /// - The buffer has enough capacity for `count` f64 elements
90    /// - No other references to this buffer exist
91    #[allow(unsafe_op_in_unsafe_fn)]
92    pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
93        debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
94        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count) }
95    }
96
97    /// Get the buffer as a mutable slice of Complex<f64>
98    ///
99    /// # Safety
100    /// Caller must ensure:
101    /// - The buffer has enough capacity for `count` Complex<f64> elements
102    /// - No other references to this buffer exist
103    #[allow(unsafe_op_in_unsafe_fn)]
104    pub unsafe fn as_complex_slice_mut(
105        &mut self,
106        count: usize,
107    ) -> &mut [num_complex::Complex<f64>] {
108        debug_assert!(
109            count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
110        );
111        unsafe {
112            std::slice::from_raw_parts_mut(
113                self.ptr.as_ptr() as *mut num_complex::Complex<f64>,
114                count,
115            )
116        }
117    }
118
119    /// Get the raw pointer
120    pub fn as_ptr(&self) -> *mut u8 {
121        self.ptr.as_ptr()
122    }
123
124    /// Get current capacity in bytes
125    pub fn capacity_bytes(&self) -> usize {
126        self.capacity_bytes
127    }
128
129    /// Deallocate the buffer
130    fn deallocate(&mut self) {
131        if let Some(layout) = self.layout.take() {
132            unsafe {
133                alloc::dealloc(self.ptr.as_ptr(), layout);
134            }
135        }
136        self.ptr = NonNull::dangling();
137        self.capacity_bytes = 0;
138    }
139}
140
141impl Default for WorkingBuffer {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
impl Drop for WorkingBuffer {
    /// Release the owned allocation; a no-op if nothing was ever allocated
    /// (deallocate checks `layout` before freeing).
    fn drop(&mut self) {
        self.deallocate();
    }
}
152
// SAFETY: WorkingBuffer exclusively owns its allocation (no sharing, no
// interior references), so moving it to another thread is safe.
// Note: Sync is intentionally NOT implemented because as_ptr(&self)
// returns a raw *mut pointer from a shared reference, which could enable
// data races if the buffer were shared across threads.
unsafe impl Send for WorkingBuffer {}
157
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    #[test]
    fn test_working_buffer_new() {
        // A fresh buffer owns no memory at all.
        assert_eq!(WorkingBuffer::new().capacity_bytes(), 0);
    }

    #[test]
    fn test_working_buffer_with_capacity() {
        let buf = WorkingBuffer::with_capacity_bytes(1024);
        assert!(buf.capacity_bytes() >= 1024);
    }

    #[test]
    fn test_working_buffer_ensure_capacity() {
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(100);
        let needed = 100 * std::mem::size_of::<f64>();
        assert!(buf.capacity_bytes() >= needed);
    }

    #[test]
    fn test_working_buffer_as_f64_slice() {
        let n = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(n);

        unsafe {
            // Fill the slice with 0.0, 1.0, ..., then read it back.
            let out = buf.as_f64_slice_mut(n);
            assert_eq!(out.len(), n);
            for (i, slot) in out.iter_mut().enumerate() {
                *slot = i as f64;
            }

            let expected: Vec<f64> = (0..n).map(|i| i as f64).collect();
            assert_eq!(buf.as_f64_slice_mut(n), expected.as_slice());
        }
    }

    #[test]
    fn test_working_buffer_as_complex_slice() {
        let n = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<Complex<f64>>(n);

        unsafe {
            // Fill with (i, i+1) pairs, then read them back.
            let out = buf.as_complex_slice_mut(n);
            assert_eq!(out.len(), n);
            for (i, slot) in out.iter_mut().enumerate() {
                *slot = Complex::new(i as f64, (i + 1) as f64);
            }

            let expected: Vec<Complex<f64>> = (0..n)
                .map(|i| Complex::new(i as f64, (i + 1) as f64))
                .collect();
            assert_eq!(buf.as_complex_slice_mut(n), expected.as_slice());
        }
    }

    #[test]
    fn test_working_buffer_reallocation() {
        let mut buf = WorkingBuffer::with_capacity_bytes(100);
        let before = buf.capacity_bytes();

        // Growing past the current capacity must reallocate to a larger size.
        buf.ensure_capacity_bytes(1000);
        assert!(buf.capacity_bytes() >= 1000);
        assert!(buf.capacity_bytes() > before);
    }
}