sparse_ir/
working_buffer.rs1use std::alloc::{self, Layout};
7use std::ptr::NonNull;
8
/// Reusable scratch allocation used as raw working memory.
///
/// The buffer owns at most one heap block, allocated with 16-byte alignment,
/// and hands out typed views over it through `unsafe` accessors. Growing the
/// buffer discards whatever it previously contained.
pub struct WorkingBuffer {
    /// Start of the owned block, or a dangling placeholder when nothing is allocated.
    ptr: NonNull<u8>,
    /// Usable size of the owned block in bytes; 0 when nothing is allocated.
    capacity_bytes: usize,
    /// Layout the block was allocated with; `None` means no block is owned.
    layout: Option<Layout>,
}
27
28impl WorkingBuffer {
29 pub fn new() -> Self {
31 Self {
32 ptr: NonNull::dangling(),
33 capacity_bytes: 0,
34 layout: None,
35 }
36 }
37
38 pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
40 if capacity_bytes == 0 {
41 return Self::new();
42 }
43
44 let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");
46
47 let ptr = unsafe { alloc::alloc(layout) };
48 let ptr = NonNull::new(ptr).expect("Allocation failed");
49
50 Self {
51 ptr,
52 capacity_bytes,
53 layout: Some(layout),
54 }
55 }
56
57 pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
62 if required_bytes <= self.capacity_bytes {
63 return;
64 }
65
66 self.deallocate();
68
69 let new_capacity = required_bytes.max(required_bytes * 3 / 2);
71 let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");
72
73 let ptr = unsafe { alloc::alloc(layout) };
74 self.ptr = NonNull::new(ptr).expect("Allocation failed");
75 self.capacity_bytes = new_capacity;
76 self.layout = Some(layout);
77 }
78
79 pub fn ensure_capacity<T>(&mut self, count: usize) {
81 let required_bytes = count * std::mem::size_of::<T>();
82 self.ensure_capacity_bytes(required_bytes);
83 }
84
85 #[allow(unsafe_op_in_unsafe_fn)]
92 pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
93 debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
94 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count) }
95 }
96
97 #[allow(unsafe_op_in_unsafe_fn)]
104 pub unsafe fn as_complex_slice_mut(
105 &mut self,
106 count: usize,
107 ) -> &mut [num_complex::Complex<f64>] {
108 debug_assert!(
109 count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
110 );
111 unsafe {
112 std::slice::from_raw_parts_mut(
113 self.ptr.as_ptr() as *mut num_complex::Complex<f64>,
114 count,
115 )
116 }
117 }
118
119 pub fn as_ptr(&self) -> *mut u8 {
121 self.ptr.as_ptr()
122 }
123
124 pub fn capacity_bytes(&self) -> usize {
126 self.capacity_bytes
127 }
128
129 fn deallocate(&mut self) {
131 if let Some(layout) = self.layout.take() {
132 unsafe {
133 alloc::dealloc(self.ptr.as_ptr(), layout);
134 }
135 }
136 self.ptr = NonNull::dangling();
137 self.capacity_bytes = 0;
138 }
139}
140
141impl Default for WorkingBuffer {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl Drop for WorkingBuffer {
148 fn drop(&mut self) {
149 self.deallocate();
150 }
151}
152
153unsafe impl Send for WorkingBuffer {}
155unsafe impl Sync for WorkingBuffer {}
156
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    #[test]
    fn test_working_buffer_new() {
        let buffer = WorkingBuffer::new();
        assert_eq!(buffer.capacity_bytes(), 0);
    }

    #[test]
    fn test_working_buffer_with_capacity() {
        let buffer = WorkingBuffer::with_capacity_bytes(1024);
        assert!(buffer.capacity_bytes() >= 1024);
    }

    #[test]
    fn test_working_buffer_ensure_capacity() {
        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<f64>(100);
        let needed = 100 * std::mem::size_of::<f64>();
        assert!(buffer.capacity_bytes() >= needed);
    }

    #[test]
    fn test_working_buffer_as_f64_slice() {
        const COUNT: usize = 10;
        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<f64>(COUNT);

        // SAFETY: room for COUNT f64 values was reserved just above.
        let written = unsafe { buffer.as_f64_slice_mut(COUNT) };
        assert_eq!(written.len(), COUNT);
        for (idx, slot) in written.iter_mut().enumerate() {
            *slot = idx as f64;
        }

        // SAFETY: same reservation; a fresh view must observe the stored values.
        let read_back = unsafe { buffer.as_f64_slice_mut(COUNT) };
        for (idx, value) in read_back.iter().enumerate() {
            assert_eq!(*value, idx as f64);
        }
    }

    #[test]
    fn test_working_buffer_as_complex_slice() {
        const COUNT: usize = 10;
        let expected = |k: usize| Complex::new(k as f64, (k + 1) as f64);

        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<Complex<f64>>(COUNT);

        // SAFETY: room for COUNT complex values was reserved just above.
        let written = unsafe { buffer.as_complex_slice_mut(COUNT) };
        assert_eq!(written.len(), COUNT);
        for (k, slot) in written.iter_mut().enumerate() {
            *slot = expected(k);
        }

        // SAFETY: same reservation; a fresh view must observe the stored values.
        let read_back = unsafe { buffer.as_complex_slice_mut(COUNT) };
        for (k, value) in read_back.iter().enumerate() {
            assert_eq!(*value, expected(k));
        }
    }

    #[test]
    fn test_working_buffer_reallocation() {
        let mut buffer = WorkingBuffer::with_capacity_bytes(100);
        let before = buffer.capacity_bytes();

        buffer.ensure_capacity_bytes(1000);
        let after = buffer.capacity_bytes();
        assert!(after >= 1000);
        assert!(after > before);
    }
}