sparse_ir/
working_buffer.rs1use std::alloc::{self, Layout};
7use std::ptr::NonNull;
8
/// Untyped, reusable scratch allocation that can be viewed as typed slices.
///
/// The buffer tracks a raw byte pointer together with the layout it was
/// allocated with, so the memory can be released with the matching layout.
#[derive(Debug)]
pub struct WorkingBuffer {
    /// Start of the allocation; a dangling placeholder while nothing is allocated.
    ptr: NonNull<u8>,
    /// Size of the current allocation in bytes (0 while nothing is allocated).
    capacity_bytes: usize,
    /// Layout the current allocation was requested with; `None` while nothing
    /// is allocated.
    layout: Option<Layout>,
}
27
28impl WorkingBuffer {
29 pub fn new() -> Self {
31 Self {
32 ptr: NonNull::dangling(),
33 capacity_bytes: 0,
34 layout: None,
35 }
36 }
37
38 pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
40 if capacity_bytes == 0 {
41 return Self::new();
42 }
43
44 let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");
46
47 let ptr = unsafe { alloc::alloc(layout) };
48 let ptr = NonNull::new(ptr).expect("Allocation failed");
49
50 Self {
51 ptr,
52 capacity_bytes,
53 layout: Some(layout),
54 }
55 }
56
57 pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
62 if required_bytes <= self.capacity_bytes {
63 return;
64 }
65
66 self.deallocate();
68
69 let new_capacity = required_bytes.max(required_bytes * 3 / 2);
71 let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");
72
73 let ptr = unsafe { alloc::alloc(layout) };
74 self.ptr = NonNull::new(ptr).expect("Allocation failed");
75 self.capacity_bytes = new_capacity;
76 self.layout = Some(layout);
77 }
78
79 pub fn ensure_capacity<T>(&mut self, count: usize) {
81 let required_bytes = count * std::mem::size_of::<T>();
82 self.ensure_capacity_bytes(required_bytes);
83 }
84
85 #[allow(unsafe_op_in_unsafe_fn)]
92 pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
93 debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
94 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count) }
95 }
96
97 #[allow(unsafe_op_in_unsafe_fn)]
104 pub unsafe fn as_complex_slice_mut(
105 &mut self,
106 count: usize,
107 ) -> &mut [num_complex::Complex<f64>] {
108 debug_assert!(
109 count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
110 );
111 unsafe {
112 std::slice::from_raw_parts_mut(
113 self.ptr.as_ptr() as *mut num_complex::Complex<f64>,
114 count,
115 )
116 }
117 }
118
119 pub fn as_ptr(&self) -> *mut u8 {
121 self.ptr.as_ptr()
122 }
123
124 pub fn capacity_bytes(&self) -> usize {
126 self.capacity_bytes
127 }
128
129 fn deallocate(&mut self) {
131 if let Some(layout) = self.layout.take() {
132 unsafe {
133 alloc::dealloc(self.ptr.as_ptr(), layout);
134 }
135 }
136 self.ptr = NonNull::dangling();
137 self.capacity_bytes = 0;
138 }
139}
140
141impl Default for WorkingBuffer {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl Drop for WorkingBuffer {
148 fn drop(&mut self) {
149 self.deallocate();
150 }
151}
152
153unsafe impl Send for WorkingBuffer {}
157
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// A freshly constructed buffer reports zero capacity.
    #[test]
    fn test_working_buffer_new() {
        assert_eq!(WorkingBuffer::new().capacity_bytes(), 0);
    }

    /// Constructing with a size yields at least that many bytes.
    #[test]
    fn test_working_buffer_with_capacity() {
        let requested = 1024;
        let buf = WorkingBuffer::with_capacity_bytes(requested);
        assert!(buf.capacity_bytes() >= requested);
    }

    /// Typed capacity requests are converted to a sufficient byte capacity.
    #[test]
    fn test_working_buffer_ensure_capacity() {
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(100);

        let needed = 100 * std::mem::size_of::<f64>();
        assert!(buf.capacity_bytes() >= needed);
    }

    /// Values written through one `f64` view are visible through the next.
    #[test]
    fn test_working_buffer_as_f64_slice() {
        const COUNT: usize = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<f64>(COUNT);

        // SAFETY: the buffer was just sized for COUNT f64 values and every
        // element is written below before anything reads it.
        let slice = unsafe { buf.as_f64_slice_mut(COUNT) };
        assert_eq!(slice.len(), COUNT);
        for (i, value) in slice.iter_mut().enumerate() {
            *value = i as f64;
        }

        // SAFETY: same capacity as above; all elements are now initialized.
        let slice = unsafe { buf.as_f64_slice_mut(COUNT) };
        for (i, value) in slice.iter().enumerate() {
            assert_eq!(*value, i as f64);
        }
    }

    /// Values written through one complex view are visible through the next.
    #[test]
    fn test_working_buffer_as_complex_slice() {
        const COUNT: usize = 10;
        let mut buf = WorkingBuffer::new();
        buf.ensure_capacity::<Complex<f64>>(COUNT);

        // SAFETY: the buffer was just sized for COUNT complex values and
        // every element is written below before anything reads it.
        let slice = unsafe { buf.as_complex_slice_mut(COUNT) };
        assert_eq!(slice.len(), COUNT);
        for (i, value) in slice.iter_mut().enumerate() {
            *value = Complex::new(i as f64, (i + 1) as f64);
        }

        // SAFETY: same capacity as above; all elements are now initialized.
        let slice = unsafe { buf.as_complex_slice_mut(COUNT) };
        for (i, value) in slice.iter().enumerate() {
            assert_eq!(*value, Complex::new(i as f64, (i + 1) as f64));
        }
    }

    /// Asking for more than the current capacity grows the buffer.
    #[test]
    fn test_working_buffer_reallocation() {
        let mut buf = WorkingBuffer::with_capacity_bytes(100);
        let before = buf.capacity_bytes();

        buf.ensure_capacity_bytes(1000);
        let after = buf.capacity_bytes();

        assert!(after >= 1000);
        assert!(after > before);
    }
}