sparse_ir/
working_buffer.rs1use std::alloc::{self, Layout};
7use std::ptr::NonNull;
8
/// A reusable, 16-byte-aligned scratch allocation that can be viewed as typed slices.
///
/// The buffer only grows (see `ensure_capacity_bytes`), and growing discards the previous
/// contents. An empty buffer holds a dangling pointer and owns no heap memory.
#[derive(Debug)]
pub struct WorkingBuffer {
    /// Start of the owned block, or a dangling (non-null) pointer when `layout` is `None`.
    ptr: NonNull<u8>,
    /// Usable size of the block in bytes; 0 when nothing is allocated.
    capacity_bytes: usize,
    /// Layout the block was allocated with. `Some` exactly while `ptr` owns heap memory,
    /// and needed again to free it.
    layout: Option<Layout>,
}
27
28impl WorkingBuffer {
29 pub fn new() -> Self {
31 Self {
32 ptr: NonNull::dangling(),
33 capacity_bytes: 0,
34 layout: None,
35 }
36 }
37
38 pub fn with_capacity_bytes(capacity_bytes: usize) -> Self {
40 if capacity_bytes == 0 {
41 return Self::new();
42 }
43
44 let layout = Layout::from_size_align(capacity_bytes, 16).expect("Invalid layout");
46
47 let ptr = unsafe { alloc::alloc(layout) };
48 let ptr = NonNull::new(ptr).expect("Allocation failed");
49
50 Self {
51 ptr,
52 capacity_bytes,
53 layout: Some(layout),
54 }
55 }
56
57 pub fn ensure_capacity_bytes(&mut self, required_bytes: usize) {
62 if required_bytes <= self.capacity_bytes {
63 return;
64 }
65
66 self.deallocate();
68
69 let new_capacity = required_bytes.max(required_bytes * 3 / 2);
71 let layout = Layout::from_size_align(new_capacity, 16).expect("Invalid layout");
72
73 let ptr = unsafe { alloc::alloc(layout) };
74 self.ptr = NonNull::new(ptr).expect("Allocation failed");
75 self.capacity_bytes = new_capacity;
76 self.layout = Some(layout);
77 }
78
79 pub fn ensure_capacity<T>(&mut self, count: usize) {
81 let required_bytes = count * std::mem::size_of::<T>();
82 self.ensure_capacity_bytes(required_bytes);
83 }
84
85 pub unsafe fn as_f64_slice_mut(&mut self, count: usize) -> &mut [f64] {
92 debug_assert!(count * std::mem::size_of::<f64>() <= self.capacity_bytes);
93 std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f64, count)
94 }
95
96 pub unsafe fn as_complex_slice_mut(
103 &mut self,
104 count: usize,
105 ) -> &mut [num_complex::Complex<f64>] {
106 debug_assert!(
107 count * std::mem::size_of::<num_complex::Complex<f64>>() <= self.capacity_bytes
108 );
109 std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut num_complex::Complex<f64>, count)
110 }
111
112 pub fn as_ptr(&self) -> *mut u8 {
114 self.ptr.as_ptr()
115 }
116
117 pub fn capacity_bytes(&self) -> usize {
119 self.capacity_bytes
120 }
121
122 fn deallocate(&mut self) {
124 if let Some(layout) = self.layout.take() {
125 unsafe {
126 alloc::dealloc(self.ptr.as_ptr(), layout);
127 }
128 }
129 self.ptr = NonNull::dangling();
130 self.capacity_bytes = 0;
131 }
132}
133
134impl Default for WorkingBuffer {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl Drop for WorkingBuffer {
141 fn drop(&mut self) {
142 self.deallocate();
143 }
144}
145
// SAFETY: `WorkingBuffer` exclusively owns the block behind `ptr`. The block is allocated in
// `with_capacity_bytes` / `ensure_capacity_bytes` and freed only in `deallocate`, and no other
// object owns or frees it. Moving the buffer to another thread is therefore sound.
unsafe impl Send for WorkingBuffer {}
// SAFETY: the `&self` methods (`as_ptr`, `capacity_bytes`) only read the struct's fields.
// Writing to the block requires `&mut self`, or dereferencing the pointer returned by
// `as_ptr`. That dereference is itself `unsafe`, so the caller is responsible for avoiding
// data races.
unsafe impl Sync for WorkingBuffer {}
149
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    /// A freshly constructed buffer reports zero capacity.
    #[test]
    fn test_working_buffer_new() {
        assert_eq!(WorkingBuffer::new().capacity_bytes(), 0);
    }

    /// An explicit byte capacity is honoured.
    #[test]
    fn test_working_buffer_with_capacity() {
        let requested = 1024;
        let buffer = WorkingBuffer::with_capacity_bytes(requested);
        assert!(buffer.capacity_bytes() >= requested);
    }

    /// Typed capacity requests are converted to bytes.
    #[test]
    fn test_working_buffer_ensure_capacity() {
        const N: usize = 100;
        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<f64>(N);
        assert!(buffer.capacity_bytes() >= N * std::mem::size_of::<f64>());
    }

    /// Values written through one `f64` view are visible through a fresh view.
    #[test]
    fn test_working_buffer_as_f64_slice() {
        const N: usize = 10;
        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<f64>(N);

        let written = unsafe { buffer.as_f64_slice_mut(N) };
        assert_eq!(written.len(), N);
        for (idx, slot) in written.iter_mut().enumerate() {
            *slot = idx as f64;
        }

        let reread = unsafe { buffer.as_f64_slice_mut(N) };
        for (idx, &value) in reread.iter().enumerate() {
            assert_eq!(value, idx as f64);
        }
    }

    /// Values written through one complex view are visible through a fresh view.
    #[test]
    fn test_working_buffer_as_complex_slice() {
        const N: usize = 10;
        let mut buffer = WorkingBuffer::new();
        buffer.ensure_capacity::<Complex<f64>>(N);

        let written = unsafe { buffer.as_complex_slice_mut(N) };
        assert_eq!(written.len(), N);
        for (idx, slot) in written.iter_mut().enumerate() {
            *slot = Complex::new(idx as f64, (idx + 1) as f64);
        }

        let reread = unsafe { buffer.as_complex_slice_mut(N) };
        for (idx, &value) in reread.iter().enumerate() {
            assert_eq!(value, Complex::new(idx as f64, (idx + 1) as f64));
        }
    }

    /// Requesting more than the current capacity grows the buffer.
    #[test]
    fn test_working_buffer_reallocation() {
        let mut buffer = WorkingBuffer::with_capacity_bytes(100);
        let before = buffer.capacity_bytes();

        buffer.ensure_capacity_bytes(1000);
        let after = buffer.capacity_bytes();
        assert!(after >= 1000);
        assert!(after > before);
    }
}