sparse_ir/
working_buffer.rs

//! Working buffer for in-place operations
//!
//! Provides a reusable buffer that can be used for temporary storage during
//! evaluate/fit operations, avoiding repeated allocations.

use std::alloc::{self, Layout};
use std::ptr::NonNull;

/// A reusable working buffer for temporary storage
///
/// This buffer manages raw memory that can be interpreted as different types
/// (f64 or Complex<f64>) depending on the operation. It automatically grows
/// when more space is needed.
///
/// # Safety
/// This struct manages raw memory and must be used carefully:
/// - The buffer is aligned for Complex<f64> (16 bytes)
/// - When casting to different types, ensure alignment requirements are met
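///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest, since the crate path of
/// this type is not shown here):
///
/// ```ignore
/// let mut buf = WorkingBuffer::new();
/// buf.ensure_capacity::<f64>(4);
/// // SAFETY: capacity for 4 f64 elements was just ensured, we hold the only
/// // reference into the buffer, and every element is written before it is read.
/// let slice = unsafe { buf.as_f64_slice_mut(4) };
/// for (i, x) in slice.iter_mut().enumerate() {
///     *x = i as f64;
/// }
/// assert_eq!(slice[3], 3.0);
/// ```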
pub struct WorkingBuffer {
    /// Raw pointer to the buffer
    ptr: NonNull<u8>,
    /// Capacity in bytes
    capacity_bytes: usize,
    /// Current layout (for deallocation)
    layout: Option<Layout>,
}

impl WorkingBuffer {
    /// Create a new empty working buffer
    pub fn new() -> Self {
        Self {
            ptr: NonNull::dangling(),
            capacity_bytes: 0,
            layout: None,
        }
    }

    /// Create a new working buffer with initial capacity (in bytes)
    pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
        if capacity_bytes == 0 {
            return Self::new();
        }

        // Align to 16 bytes for Complex<f64> compatibility
        let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");

        // Report allocation failure through the standard handler
        let ptr = unsafe { alloc::alloc(layout) };
        let ptr = NonNull::new(ptr).unwrap_or_else(|| alloc::handle_alloc_error(layout));

        Self {
            ptr,
            capacity_bytes,
            layout: Some(layout),
        }
    }

    /// Ensure the buffer has at least the specified capacity in bytes
    ///
    /// If the current capacity is insufficient, the buffer is reallocated.
    /// Existing data is NOT preserved.
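    ///
    /// A short sketch of the intended use (illustrative; the exact growth
    /// factor is an internal detail):
    ///
    /// ```ignore
    /// let mut buf = WorkingBuffer::new();
    /// buf.ensure_capacity_bytes(256);
    /// assert!(buf.capacity_bytes() >= 256);
    /// // Requesting less than the current capacity is a no-op.
    /// buf.ensure_capacity_bytes(64);
    /// assert!(buf.capacity_bytes() >= 256);
    /// ```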
    pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
        if required_bytes <= self.capacity_bytes {
            return;
        }

        // Deallocate old buffer if any
        self.deallocate();

        // Allocate with ~1.5x headroom to avoid frequent reallocations
        let new_capacity = required_bytes + required_bytes / 2;
        let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");

        // Report allocation failure through the standard handler
        let ptr = unsafe { alloc::alloc(layout) };
        self.ptr = NonNull::new(ptr).unwrap_or_else(|| alloc::handle_alloc_error(layout));
        self.capacity_bytes = new_capacity;
        self.layout = Some(layout);
    }

    /// Ensure the buffer can hold at least `count` elements of type `T`
    ///
    /// The backing allocation is 16-byte aligned, so `T` must not require an
    /// alignment larger than 16 bytes.
    pub fn ensure_capacity<T>(&mut self, count: usize) {
        debug_assert!(std::mem::align_of::<T>() <= 16);
        let required_bytes = count * std::mem::size_of::<T>();
        self.ensure_capacity_bytes(required_bytes);
    }

    /// Get the buffer as a mutable slice of f64
    ///
    /// # Safety
    /// Caller must ensure:
    /// - The buffer has enough capacity for `count` f64 elements
    /// - No other pointers or references into the buffer are in use
    /// - Every element is written before it is read; freshly (re)allocated
    ///   memory is uninitialized
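    ///
    /// A typical call pattern (sketch; `buf` is a `WorkingBuffer` and `n` the
    /// element count):
    ///
    /// ```ignore
    /// buf.ensure_capacity::<f64>(n);
    /// // SAFETY: capacity for `n` f64 elements was just ensured, and the slice
    /// // is filled before any element is read.
    /// let out = unsafe { buf.as_f64_slice_mut(n) };
    /// out.fill(0.0);
    /// ```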
    pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
        debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
        std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count)
    }

    /// Get the buffer as a mutable slice of Complex<f64>
    ///
    /// # Safety
    /// Caller must ensure:
    /// - The buffer has enough capacity for `count` Complex<f64> elements
    /// - No other pointers or references into the buffer are in use
    /// - Every element is written before it is read; freshly (re)allocated
    ///   memory is uninitialized
    pub unsafe fn as_complex_slice_mut(
        &mut self,
        count: usize,
    ) -> &mut [num_complex::Complex<f64>] {
        debug_assert!(
            count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
        );
        std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut num_complex::Complex<f64>, count)
    }

    /// Get the raw pointer
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// Get current capacity in bytes
    pub fn capacity_bytes(&self) -> usize {
        self.capacity_bytes
    }

    /// Deallocate the buffer
    fn deallocate(&mut self) {
        if let Some(layout) = self.layout.take() {
            unsafe {
                alloc::dealloc(self.ptr.as_ptr(), layout);
            }
        }
        self.ptr = NonNull::dangling();
        self.capacity_bytes = 0;
    }
}

impl Default for WorkingBuffer {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for WorkingBuffer {
    fn drop(&mut self) {
        self.deallocate();
    }
}

// SAFETY: WorkingBuffer uniquely owns its allocation, so it can be sent to
// another thread (Send); all mutation goes through `&mut self`, so a shared
// `&WorkingBuffer` allows no data races (Sync).
unsafe impl Send for WorkingBuffer {}
unsafe impl Sync for WorkingBuffer {}

#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    #[test]
    fn test_working_buffer_new() {
        let buf = WorkingBuffer::new();
        assert_eq!(buf.capacity_bytes(), 0);
    }

    #[test]
    fn test_working_buffer_with_capacity() {
        let buf = WorkingBuffer::with_capacity_bytes(1024);
        assert!(buf.capacity_bytes() >= 1024);
    }

    #[test]
    fn test_working_buffer_ensure_capacity() {
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(100);
        assert!(buf.capacity_bytes() >= 100 * std::mem::size_of::<f64>());
    }

    #[test]
    fn test_working_buffer_as_f64_slice() {
        let mut buf = WorkingBuffer::new();
        let count = 10;
        buf.ensure_capacity::<f64>(count);

        unsafe {
            let slice = buf.as_f64_slice_mut(count);
            assert_eq!(slice.len(), count);

            // Write some data
            for i in 0..count {
                slice[i] = i as f64;
            }

            // Read it back
            let slice = buf.as_f64_slice_mut(count);
            for i in 0..count {
                assert_eq!(slice[i], i as f64);
            }
        }
    }

    #[test]
    fn test_working_buffer_as_complex_slice() {
        let mut buf = WorkingBuffer::new();
        let count = 10;
        buf.ensure_capacity::<Complex<f64>>(count);

        unsafe {
            let slice = buf.as_complex_slice_mut(count);
            assert_eq!(slice.len(), count);

            // Write some data
            for i in 0..count {
                slice[i] = Complex::new(i as f64, (i + 1) as f64);
            }

            // Read it back
            let slice = buf.as_complex_slice_mut(count);
            for i in 0..count {
                assert_eq!(slice[i], Complex::new(i as f64, (i + 1) as f64));
            }
        }
    }

    #[test]
    fn test_working_buffer_reallocation() {
        let mut buf = WorkingBuffer::with_capacity_bytes(100);
        let old_capacity = buf.capacity_bytes();

        buf.ensure_capacity_bytes(1000);
        assert!(buf.capacity_bytes() >= 1000);
        assert!(buf.capacity_bytes() > old_capacity);
    }
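
    // Illustrative sketch: the same buffer can be reused for f64 data and then
    // for Complex<f64> data, growing only when more bytes are needed.
    #[test]
    fn test_working_buffer_reuse_across_types() {
        let mut buf = WorkingBuffer::new();

        buf.ensure_capacity::<f64>(8);
        unsafe {
            let slice = buf.as_f64_slice_mut(8);
            slice.fill(1.5);
            assert!(slice.iter().all(|&x| x == 1.5));
        }
        let cap_after_f64 = buf.capacity_bytes();

        // Complex<f64> elements are twice as large, so this may grow the buffer.
        buf.ensure_capacity::<Complex<f64>>(8);
        assert!(buf.capacity_bytes() >= 8 * std::mem::size_of::<Complex<f64>>());
        assert!(buf.capacity_bytes() >= cap_after_f64);
        unsafe {
            let slice = buf.as_complex_slice_mut(8);
            slice.fill(Complex::new(0.0, 1.0));
            assert_eq!(slice[7], Complex::new(0.0, 1.0));
        }
    }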
}