use core::alloc::{GlobalAlloc, Layout};
use core::ffi::c_void;
use core::ptr::NonNull;

use super::error::{KmError, KmResult};

/// Kernel pool types, with values matching the `POOL_TYPE` constants
/// accepted by `ExAllocatePoolWithTag`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolType {
    /// Legacy non-paged pool; executable on older systems.
    NonPaged = 0,
    /// Paged pool; must only be touched at IRQL < DISPATCH_LEVEL.
    Paged = 1,
    /// Non-paged, no-execute pool (`NonPagedPoolNx`); the safe default.
    NonPagedNx = 512,
    /// Session-local non-paged pool.
    NonPagedSession = 32,
    /// Session-local paged pool.
    PagedSession = 33,
}

impl Default for PoolType {
    fn default() -> Self {
        Self::NonPagedNx
    }
}

/// Four-character pool tag, as shown by tools such as poolmon and WinDbg.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolTag(pub u32);

impl PoolTag {
    /// Builds a tag from four ASCII bytes. Little-endian packing keeps the
    /// characters in reading order in a pool dump.
    pub const fn from_chars(chars: [u8; 4]) -> Self {
        Self(u32::from_le_bytes(chars))
    }

    /// Default tag used by this driver.
    pub const WRAITH: Self = Self::from_chars(*b"WRAT");
}

impl Default for PoolTag {
    fn default() -> Self {
        Self::WRAITH
    }
}

/// A tagged allocator over the kernel pool routines.
pub struct PoolAllocator {
    pool_type: PoolType,
    tag: PoolTag,
}

impl PoolAllocator {
    /// Creates an allocator for the given pool type and tag.
    pub const fn new(pool_type: PoolType, tag: PoolTag) -> Self {
        Self { pool_type, tag }
    }

    /// Non-paged, no-execute allocator with the default tag.
    pub const fn non_paged() -> Self {
        Self::new(PoolType::NonPagedNx, PoolTag::WRAITH)
    }

    /// Paged allocator with the default tag.
    pub const fn paged() -> Self {
        Self::new(PoolType::Paged, PoolTag::WRAITH)
    }

    /// Allocates `size` bytes from the pool.
    pub fn allocate(&self, size: usize) -> KmResult<NonNull<u8>> {
        if size == 0 {
            return Err(KmError::InvalidParameter {
                context: "allocate: size cannot be zero",
            });
        }

        // SAFETY: pool type and size are valid; a null return is handled below.
        let ptr = unsafe {
            ExAllocatePoolWithTag(self.pool_type as u32, size, self.tag.0)
        };

        NonNull::new(ptr as *mut u8).ok_or(KmError::PoolAllocationFailed {
            size,
            pool_type: self.pool_type as u32,
        })
    }

    /// Allocates `size` bytes and zeroes them before returning.
    pub fn allocate_zeroed(&self, size: usize) -> KmResult<NonNull<u8>> {
        let ptr = self.allocate(size)?;
        // SAFETY: `ptr` was just allocated and is valid for `size` writable bytes.
        unsafe {
            core::ptr::write_bytes(ptr.as_ptr(), 0, size);
        }
        Ok(ptr)
    }

    /// Frees an allocation previously returned by this allocator.
    ///
    /// # Safety
    ///
    /// `ptr` must have come from `allocate` on an allocator with the same
    /// tag, and must not be used after this call.
    pub unsafe fn free(&self, ptr: NonNull<u8>) {
        // SAFETY: upheld by the caller per the contract above.
        unsafe {
            ExFreePoolWithTag(ptr.as_ptr() as *mut c_void, self.tag.0);
        }
    }

    /// Resizes an allocation by allocate-copy-free; the pool routines have
    /// no in-place resize.
    ///
    /// # Safety
    ///
    /// `old_ptr` must have come from `allocate` on this allocator, and
    /// `old_size` must not exceed the size it was allocated with. On success
    /// (and on the zero-size error path) `old_ptr` is freed and must not be
    /// used again.
    pub unsafe fn reallocate(
        &self,
        old_ptr: NonNull<u8>,
        old_size: usize,
        new_size: usize,
    ) -> KmResult<NonNull<u8>> {
        if new_size == 0 {
            unsafe { self.free(old_ptr) };
            return Err(KmError::InvalidParameter {
                context: "reallocate: new_size cannot be zero",
            });
        }

        let new_ptr = self.allocate(new_size)?;

        // SAFETY: both regions are valid for `copy_size` bytes and cannot
        // overlap, since `new_ptr` is a fresh allocation.
        unsafe {
            let copy_size = core::cmp::min(old_size, new_size);
            core::ptr::copy_nonoverlapping(old_ptr.as_ptr(), new_ptr.as_ptr(), copy_size);
            self.free(old_ptr);
        }

        Ok(new_ptr)
    }
}
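
// Example (a sketch): a one-off allocation managed by hand. Prefer
// `PoolBuffer` below, which frees the memory on drop.
//
//     let alloc = PoolAllocator::non_paged();
//     let scratch = alloc.allocate(64)?;
//     // ... write and read the 64 bytes at `scratch` ...
//     unsafe { alloc.free(scratch) };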

/// `#[global_allocator]`-compatible allocator backed by the non-paged,
/// no-execute pool.
pub struct KernelAllocator;

impl KernelAllocator {
    const ALLOCATOR: PoolAllocator = PoolAllocator::non_paged();
}
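
// Registering the allocator (a sketch; this goes once in the driver crate
// root, after which `alloc` collections such as `Box` and `Vec` work):
//
//     #[global_allocator]
//     static ALLOCATOR: KernelAllocator = KernelAllocator;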

unsafe impl GlobalAlloc for KernelAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let align = layout.align();
        let size = layout.size();

        if align <= 16 {
            // Pool allocations are already 16-byte aligned on 64-bit
            // Windows, so small alignments need no extra work.
            match Self::ALLOCATOR.allocate(size) {
                Ok(ptr) => ptr.as_ptr(),
                Err(_) => core::ptr::null_mut(),
            }
        } else {
            // Over-allocate and stash the original address in a header word
            // just below the aligned pointer so `dealloc` can recover it.
            // Reserving `align + header` extra bytes guarantees the header
            // fits even when the raw pointer is already aligned.
            let header = core::mem::size_of::<usize>();
            let total_size = match size.checked_add(align + header) {
                Some(total) => total,
                None => return core::ptr::null_mut(),
            };
            match Self::ALLOCATOR.allocate(total_size) {
                Ok(ptr) => {
                    let raw = ptr.as_ptr() as usize;
                    let aligned = (raw + header + align - 1) & !(align - 1);
                    let aligned_ptr = aligned as *mut u8;
                    unsafe {
                        *((aligned_ptr as *mut usize).offset(-1)) = raw;
                    }
                    aligned_ptr
                }
                Err(_) => core::ptr::null_mut(),
            }
        }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        if ptr.is_null() {
            return;
        }

        let align = layout.align();

        let actual_ptr = if align <= 16 {
            ptr
        } else {
            // Recover the original allocation address from the header word
            // written by `alloc`.
            let raw = unsafe { *((ptr as *mut usize).offset(-1)) };
            raw as *mut u8
        };

        if let Some(ptr) = NonNull::new(actual_ptr) {
            unsafe { Self::ALLOCATOR.free(ptr) };
        }
    }

    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let new_layout = match Layout::from_size_align(new_size, layout.align()) {
            Ok(l) => l,
            Err(_) => return core::ptr::null_mut(),
        };

        // Allocate-copy-free, preserving the old allocation on failure as
        // the `GlobalAlloc` contract requires.
        unsafe {
            let new_ptr = self.alloc(new_layout);
            if !new_ptr.is_null() {
                let copy_size = core::cmp::min(layout.size(), new_size);
                core::ptr::copy_nonoverlapping(ptr, new_ptr, copy_size);
                self.dealloc(ptr, layout);
            }
            new_ptr
        }
    }
}

/// An owned, fixed-size pool allocation that is freed on drop.
pub struct PoolBuffer {
    ptr: NonNull<u8>,
    size: usize,
    allocator: PoolAllocator,
}

impl PoolBuffer {
    /// Allocates an uninitialized buffer. Prefer [`Self::zeroed`] unless
    /// every byte is written before it is read.
    pub fn new(size: usize, pool_type: PoolType) -> KmResult<Self> {
        let allocator = PoolAllocator::new(pool_type, PoolTag::WRAITH);
        let ptr = allocator.allocate(size)?;
        Ok(Self { ptr, size, allocator })
    }

    /// Allocates a zero-initialized buffer.
    pub fn zeroed(size: usize, pool_type: PoolType) -> KmResult<Self> {
        let allocator = PoolAllocator::new(pool_type, PoolTag::WRAITH);
        let ptr = allocator.allocate_zeroed(size)?;
        Ok(Self { ptr, size, allocator })
    }

    /// Raw pointer to the start of the buffer.
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// Views the buffer as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
    }

    /// Views the buffer as a mutable byte slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
    }

    /// Size of the buffer in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Releases ownership without freeing. The caller becomes responsible
    /// for eventually freeing the allocation with a matching tag.
    pub fn leak(self) -> NonNull<u8> {
        let ptr = self.ptr;
        core::mem::forget(self);
        ptr
    }
}
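
// Example (a sketch): a zeroed scratch buffer, freed automatically when it
// goes out of scope.
//
//     let mut buf = PoolBuffer::zeroed(4096, PoolType::NonPagedNx)?;
//     buf.as_mut_slice()[..4].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);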

impl Drop for PoolBuffer {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated by `allocator` and is freed exactly once.
        unsafe { self.allocator.free(self.ptr) };
    }
}

extern "system" {
    // Pool routines exported by ntoskrnl.exe.
    fn ExAllocatePoolWithTag(PoolType: u32, NumberOfBytes: usize, Tag: u32) -> *mut c_void;
    fn ExFreePoolWithTag(P: *mut c_void, Tag: u32);
}