// wraith/km/memory.rs

//! Kernel memory operations: MDL, physical memory, virtual memory

use core::ffi::c_void;
use core::ptr::NonNull;

use super::error::{status, KmError, KmResult, NtStatus};

/// MDL flags (MDL_* constants from wdm.h)
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MdlFlags {
    /// buffer is mapped to a system virtual address (MDL_MAPPED_TO_SYSTEM_VA)
    MappedToSystemVa = 0x0001,
    /// pages are locked in memory (MDL_PAGES_LOCKED)
    PagesLocked = 0x0002,
    /// buffer comes from nonpaged pool (MDL_SOURCE_IS_NONPAGED_POOL)
    SourceIsNonpagedPool = 0x0004,
    /// allocated with MDL_ALLOCATED_FIXED_SIZE
    AllocatedFixedSize = 0x0008,
    /// partial MDL (MDL_PARTIAL)
    Partial = 0x0010,
    /// partial MDL has been mapped (MDL_PARTIAL_HAS_BEEN_MAPPED)
    PartialHasBeenMapped = 0x0020,
    /// MDL describes an I/O page-read operation (MDL_IO_PAGE_READ)
    IoPageRead = 0x0040,
    /// transfer writes into the described pages (MDL_WRITE_OPERATION)
    WriteOperation = 0x0080,
    /// page tables are locked (MDL_LOCKED_PAGE_TABLES)
    LockedPageTables = 0x0100,
    /// buffer lies in I/O space (MDL_IO_SPACE)
    IoSpace = 0x0800,
    /// network header MDL (MDL_NETWORK_HEADER)
    NetworkHeader = 0x1000,
    /// mapping may fail instead of bugchecking (MDL_MAPPING_CAN_FAIL)
    MappingCanFail = 0x2000,
    /// internal MDL flag (MDL_ALLOCATED_MUST_SUCCEED)
    AllocatedMustSucceed = 0x4000,
    /// internal MDL flag (MDL_INTERNAL)
    Internal = 0x8000,
}

/// raw memory descriptor list layout (header only; the PFN array follows)
#[repr(C)]
pub struct MdlRaw {
    pub next: *mut MdlRaw,
    pub size: i16,
    pub mdl_flags: i16,
    pub process: *mut c_void,
    pub mapped_system_va: *mut c_void,
    pub start_va: *mut c_void,
    pub byte_count: u32,
    pub byte_offset: u32,
    // PFN array follows
}
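
// The PFN array that follows the header has one entry per page spanned by the
// described buffer. A minimal sketch of the span calculation (assumes 4 KiB
// pages; mirrors the ADDRESS_AND_SIZE_TO_SPAN_PAGES macro):
//
//     fn pfn_entries(start_va: usize, byte_count: usize) -> usize {
//         const PAGE_SIZE: usize = 0x1000;
//         ((start_va & (PAGE_SIZE - 1)) + byte_count + PAGE_SIZE - 1) / PAGE_SIZE
//     }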

/// safe MDL wrapper with RAII cleanup
pub struct Mdl {
    raw: *mut MdlRaw,
    /// pages locked via MmProbeAndLockPages
    locked: bool,
    /// system mapping created via MmMapLockedPagesSpecifyCache
    mapped: bool,
    system_address: Option<NonNull<c_void>>,
}

impl Mdl {
    /// create an MDL describing a virtual address range
    pub fn create(virtual_address: *mut c_void, length: usize) -> KmResult<Self> {
        // SAFETY: IoAllocateMdl only builds the descriptor; it does not touch the range
        let raw = unsafe {
            IoAllocateMdl(
                virtual_address,
                length as u32,
                0, // not a secondary buffer
                0, // don't charge quota
                core::ptr::null_mut(), // no IRP
            )
        };

        if raw.is_null() {
            return Err(KmError::MdlOperationFailed {
                reason: "IoAllocateMdl returned null",
            });
        }

        Ok(Self {
            raw,
            locked: false,
            mapped: false,
            system_address: None,
        })
    }

    /// lock the described pages in memory (required for user-mode buffers)
    pub fn lock_pages(&mut self, access_mode: AccessMode, operation: LockOperation) -> KmResult<()> {
        if self.locked {
            return Ok(());
        }

        // SAFETY: MDL is valid; MmProbeAndLockPages returns no status and
        // instead raises an exception on failure, which a C caller would catch
        // with SEH; here success is assumed
        unsafe {
            MmProbeAndLockPages(self.raw, access_mode as u8, operation as u32);
        }

        self.locked = true;
        Ok(())
    }

    /// get a system-space address for the MDL's buffer
    ///
    /// MmGetSystemAddressForMdlSafe is a wdm.h macro, not an ntoskrnl export,
    /// so its logic is reproduced here with MmMapLockedPagesSpecifyCache
    pub fn system_address(&mut self) -> KmResult<NonNull<c_void>> {
        if let Some(addr) = self.system_address {
            return Ok(addr);
        }

        // SAFETY: MDL is valid and its pages are locked or nonpaged-pool backed
        let addr = unsafe {
            let already_mapped = (*self.raw).mdl_flags
                & (MdlFlags::MappedToSystemVa as i16 | MdlFlags::SourceIsNonpagedPool as i16);
            if already_mapped != 0 {
                (*self.raw).mapped_system_va
            } else {
                let va = MmMapLockedPagesSpecifyCache(
                    self.raw,
                    AccessMode::KernelMode as u8,
                    MmCacheType::Cached as u32,
                    core::ptr::null_mut(),
                    0, // BugCheckOnFailure = FALSE
                    MmPriority::NormalPagePriority as u32,
                );
                self.mapped = !va.is_null();
                va
            }
        };

        let addr = NonNull::new(addr).ok_or(KmError::MdlOperationFailed {
            reason: "mapping the MDL to system address space failed",
        })?;

        self.system_address = Some(addr);
        Ok(addr)
    }

    /// get byte count
    pub fn byte_count(&self) -> u32 {
        // SAFETY: MDL is valid
        unsafe { (*self.raw).byte_count }
    }

    /// get raw MDL pointer
    pub fn as_raw(&self) -> *mut MdlRaw {
        self.raw
    }

    /// unlock the pages
    pub fn unlock_pages(&mut self) {
        if self.locked {
            // SAFETY: MDL is valid and pages are locked
            unsafe {
                MmUnlockPages(self.raw);
            }
            self.locked = false;
            // MmUnlockPages also releases any system mapping for the MDL
            self.mapped = false;
            self.system_address = None;
        }
    }
}

impl Drop for Mdl {
    fn drop(&mut self) {
        if self.locked {
            // unlocking also releases any system mapping for the MDL
            self.unlock_pages();
        } else if self.mapped {
            if let Some(addr) = self.system_address {
                // SAFETY: mapping was created by MmMapLockedPagesSpecifyCache
                unsafe { MmUnmapLockedPages(addr.as_ptr(), self.raw) };
            }
        }
        if !self.raw.is_null() {
            // SAFETY: MDL was allocated by IoAllocateMdl
            unsafe {
                IoFreeMdl(self.raw);
            }
        }
    }
}
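
// Usage sketch (assumptions: running at PASSIVE_LEVEL in the context of the
// process that owns the buffer; `user_ptr` and `len` stand in for an IOCTL
// input that is not part of this module):
//
//     let mut mdl = Mdl::create(user_ptr, len)?;
//     mdl.lock_pages(AccessMode::UserMode, LockOperation::IoWriteAccess)?;
//     let sys_va = mdl.system_address()?;
//     // ... access the buffer through sys_va.as_ptr() ...
//     // dropping `mdl` unlocks the pages and frees the MDL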

/// processor mode for memory operations
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessMode {
    KernelMode = 0,
    UserMode = 1,
}

/// lock operation type
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockOperation {
    IoReadAccess = 0,
    IoWriteAccess = 1,
    IoModifyAccess = 2,
}

/// page priority for MDL mapping
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmPriority {
    LowPagePriority = 0,
    NormalPagePriority = 16,
    HighPagePriority = 32,
}

/// physical memory operations
pub struct PhysicalMemory;

impl PhysicalMemory {
    /// read from physical address
    pub fn read(physical_address: u64, buffer: &mut [u8]) -> KmResult<usize> {
        if buffer.is_empty() {
            return Ok(0);
        }

        let size = buffer.len();

        // map physical to virtual
        let phys_addr = PhysicalAddress(physical_address);
        let va = unsafe {
            MmMapIoSpace(phys_addr, size, MmCacheType::NonCached as u32)
        };

        if va.is_null() {
            return Err(KmError::PhysicalMemoryFailed {
                address: physical_address,
                size,
            });
        }

        // copy data
        // SAFETY: va is valid for size bytes until MmUnmapIoSpace
        unsafe {
            core::ptr::copy_nonoverlapping(va as *const u8, buffer.as_mut_ptr(), size);
            MmUnmapIoSpace(va, size);
        }

        Ok(size)
    }

    /// write to physical address
    pub fn write(physical_address: u64, buffer: &[u8]) -> KmResult<usize> {
        if buffer.is_empty() {
            return Ok(0);
        }

        let size = buffer.len();

        let phys_addr = PhysicalAddress(physical_address);
        let va = unsafe {
            MmMapIoSpace(phys_addr, size, MmCacheType::NonCached as u32)
        };

        if va.is_null() {
            return Err(KmError::PhysicalMemoryFailed {
                address: physical_address,
                size,
            });
        }

        // SAFETY: va is valid for size bytes until MmUnmapIoSpace
        unsafe {
            core::ptr::copy_nonoverlapping(buffer.as_ptr(), va as *mut u8, size);
            MmUnmapIoSpace(va, size);
        }

        Ok(size)
    }

    /// get physical address for virtual address
    pub fn get_physical_address(virtual_address: *const c_void) -> Option<u64> {
        // SAFETY: MmGetPhysicalAddress is safe for valid VA
        let phys = unsafe { MmGetPhysicalAddress(virtual_address) };
        if phys.0 == 0 {
            None
        } else {
            Some(phys.0)
        }
    }

    /// check whether an address is backed by a valid page
    ///
    /// note: MmIsAddressValid expects a *virtual* address, so calling this with
    /// a raw physical address is only meaningful where the two coincide
    pub fn is_address_valid(physical_address: u64) -> bool {
        // SAFETY: MmIsAddressValid does not fault; it only inspects the page tables
        unsafe { MmIsAddressValid(physical_address as *const c_void) != 0 }
    }
}
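
// Usage sketch (assumptions: `pa` is a physical address that is safe to touch,
// for example one obtained from get_physical_address on a locked buffer):
//
//     let mut page = [0u8; 0x1000];
//     let bytes_read = PhysicalMemory::read(pa, &mut page)?;
//     PhysicalMemory::write(pa, &page[..bytes_read])?;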

/// cache type for memory mapping
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmCacheType {
    NonCached = 0,
    Cached = 1,
    WriteCombined = 2,
    HardwareCoherentCached = 3,
    NonCachedUnordered = 4,
}

/// physical address wrapper
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
pub struct PhysicalAddress(pub u64);

/// virtual memory operations for kernel
pub struct VirtualMemory;

impl VirtualMemory {
    /// allocate virtual memory in the current process's address space
    pub fn allocate(size: usize, protection: u32) -> KmResult<NonNull<c_void>> {
        let mut region_size = size;
        let mut base_address: *mut c_void = core::ptr::null_mut();

        // SAFETY: out-parameters point to valid locals
        let status = unsafe {
            ZwAllocateVirtualMemory(
                -1isize as *mut c_void, // NtCurrentProcess()
                &mut base_address,
                0,
                &mut region_size,
                0x3000, // MEM_COMMIT | MEM_RESERVE
                protection,
            )
        };

        if !status::nt_success(status) {
            return Err(KmError::VirtualMemoryFailed {
                address: 0,
                size,
                reason: "ZwAllocateVirtualMemory failed",
            });
        }

        NonNull::new(base_address).ok_or(KmError::VirtualMemoryFailed {
            address: 0,
            size,
            reason: "allocation returned null",
        })
    }

    /// free virtual memory
    ///
    /// # Safety
    /// address must have been allocated by VirtualMemory::allocate
    pub unsafe fn free(address: *mut c_void) -> KmResult<()> {
        let mut base = address;
        let mut size = 0usize;

        // SAFETY: caller ensures address is valid
        let status = unsafe {
            ZwFreeVirtualMemory(
                -1isize as *mut c_void,
                &mut base,
                &mut size,
                0x8000, // MEM_RELEASE
            )
        };

        if !status::nt_success(status) {
            return Err(KmError::VirtualMemoryFailed {
                address: address as usize as u64,
                size: 0,
                reason: "ZwFreeVirtualMemory failed",
            });
        }

        Ok(())
    }

    /// change memory protection
    ///
    /// note: unlike the other Zw routines used here, ZwProtectVirtualMemory is
    /// not reliably exported by ntoskrnl and may have to be resolved at runtime
    pub fn protect(
        address: *mut c_void,
        size: usize,
        new_protection: u32,
    ) -> KmResult<u32> {
        let mut old_protection = 0u32;
        let mut region_size = size;
        let mut base = address;

        // SAFETY: out-parameters point to valid locals
        let status = unsafe {
            ZwProtectVirtualMemory(
                -1isize as *mut c_void,
                &mut base,
                &mut region_size,
                new_protection,
                &mut old_protection,
            )
        };

        if !status::nt_success(status) {
            return Err(KmError::VirtualMemoryFailed {
                address: address as usize as u64,
                size,
                reason: "ZwProtectVirtualMemory failed",
            });
        }

        Ok(old_protection)
    }
}
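
// Usage sketch (assumptions: called at PASSIVE_LEVEL while attached to the
// process whose address space should receive the allocation):
//
//     let base = VirtualMemory::allocate(0x1000, protection::PAGE_READWRITE)?;
//     let _old = VirtualMemory::protect(base.as_ptr(), 0x1000, protection::PAGE_READONLY)?;
//     // SAFETY: base came from VirtualMemory::allocate
//     unsafe { VirtualMemory::free(base.as_ptr())? };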

/// kernel-mode specific memory utilities
pub struct KernelMemory;

impl KernelMemory {
    /// copy memory between addresses the caller knows to be valid
    ///
    /// there is no exception handling here; use safe_copy for addresses that
    /// might be invalid or pageable
    pub fn copy(
        destination: *mut c_void,
        source: *const c_void,
        length: usize,
    ) -> KmResult<()> {
        if destination.is_null() || source.is_null() {
            return Err(KmError::InvalidParameter {
                context: "copy: null pointer",
            });
        }

        // SAFETY: caller ensures both pointers are valid for `length` bytes
        unsafe {
            core::ptr::copy_nonoverlapping(source as *const u8, destination as *mut u8, length);
        }

        Ok(())
    }

    /// safe copy that tolerates invalid source pages (returns bytes copied)
    pub fn safe_copy(
        destination: *mut c_void,
        source: *const c_void,
        length: usize,
    ) -> KmResult<usize> {
        let mut bytes_copied = 0usize;

        // SAFETY: MmCopyMemory validates the source and fails instead of faulting
        let status = unsafe {
            MmCopyMemory(
                destination,
                MmCopyAddress { virtual_address: source },
                length,
                0x2, // MM_COPY_MEMORY_VIRTUAL
                &mut bytes_copied,
            )
        };

        if !status::nt_success(status) && bytes_copied == 0 {
            return Err(KmError::VirtualMemoryFailed {
                address: source as u64,
                size: length,
                reason: "MmCopyMemory failed",
            });
        }

        Ok(bytes_copied)
    }

    /// check if address is valid
    pub fn is_address_valid(address: *const c_void) -> bool {
        if address.is_null() {
            return false;
        }
        // SAFETY: just checking validity
        unsafe { MmIsAddressValid(address) != 0 }
    }

    /// check if every page in an address range is valid
    pub fn is_range_valid(address: *const c_void, size: usize) -> bool {
        if address.is_null() || size == 0 {
            return false;
        }

        let page_size = 0x1000usize;
        let start = address as usize;
        let end = start.saturating_add(size);

        // walk page by page, starting from the page that contains `start`,
        // so an unaligned range still has its last page checked
        let mut current = start & !(page_size - 1);

        while current < end {
            if !Self::is_address_valid(current as *const c_void) {
                return false;
            }
            current = current.saturating_add(page_size);
        }

        true
    }

    /// zero memory
    pub fn zero(address: *mut c_void, size: usize) {
        if !address.is_null() && size > 0 {
            // SAFETY: caller ensures address is valid for `size` bytes
            unsafe {
                core::ptr::write_bytes(address as *mut u8, 0, size);
            }
        }
    }
}
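
// Usage sketch (assumptions: `foreign_ptr` may point at pageable or invalid
// memory, which is why MmCopyMemory is preferred over a raw copy here):
//
//     let mut local = [0u8; 64];
//     let copied = KernelMemory::safe_copy(
//         local.as_mut_ptr() as *mut c_void,
//         foreign_ptr,
//         local.len(),
//     )?;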

/// memory copy address union
#[repr(C)]
union MmCopyAddress {
    virtual_address: *const c_void,
    physical_address: PhysicalAddress,
}

/// RAII guard for virtual memory protection changes
pub struct ProtectionGuard {
    address: *mut c_void,
    size: usize,
    old_protection: u32,
}

impl ProtectionGuard {
    /// change protection with automatic restore on drop
    pub fn new(
        address: *mut c_void,
        size: usize,
        new_protection: u32,
    ) -> KmResult<Self> {
        let old_protection = VirtualMemory::protect(address, size, new_protection)?;
        Ok(Self {
            address,
            size,
            old_protection,
        })
    }

    /// get old protection value
    pub fn old_protection(&self) -> u32 {
        self.old_protection
    }
}

impl Drop for ProtectionGuard {
    fn drop(&mut self) {
        let _ = VirtualMemory::protect(self.address, self.size, self.old_protection);
    }
}
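
// Usage sketch (assumptions: `target` and `len` describe a committed region in
// the current process; the protection values are only illustrative):
//
//     {
//         let guard = ProtectionGuard::new(target, len, protection::PAGE_READWRITE)?;
//         // ... modify the region while `guard` is alive ...
//         let _previous = guard.old_protection();
//     } // original protection is restored when the guard drops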

// memory protection constants
pub mod protection {
    pub const PAGE_NOACCESS: u32 = 0x01;
    pub const PAGE_READONLY: u32 = 0x02;
    pub const PAGE_READWRITE: u32 = 0x04;
    pub const PAGE_WRITECOPY: u32 = 0x08;
    pub const PAGE_EXECUTE: u32 = 0x10;
    pub const PAGE_EXECUTE_READ: u32 = 0x20;
    pub const PAGE_EXECUTE_READWRITE: u32 = 0x40;
    pub const PAGE_EXECUTE_WRITECOPY: u32 = 0x80;
    pub const PAGE_GUARD: u32 = 0x100;
    pub const PAGE_NOCACHE: u32 = 0x200;
}

// kernel memory functions
extern "system" {
    fn IoAllocateMdl(
        VirtualAddress: *mut c_void,
        Length: u32,
        SecondaryBuffer: u8,
        ChargeQuota: u8,
        Irp: *mut c_void,
    ) -> *mut MdlRaw;

    fn IoFreeMdl(Mdl: *mut MdlRaw);

    fn MmProbeAndLockPages(
        MemoryDescriptorList: *mut MdlRaw,
        AccessMode: u8,
        Operation: u32,
    );

    fn MmUnlockPages(MemoryDescriptorList: *mut MdlRaw);

    fn MmMapLockedPagesSpecifyCache(
        MemoryDescriptorList: *mut MdlRaw,
        AccessMode: u8,
        CacheType: u32,
        RequestedAddress: *mut c_void,
        BugCheckOnFailure: u32,
        Priority: u32,
    ) -> *mut c_void;

    fn MmUnmapLockedPages(BaseAddress: *mut c_void, MemoryDescriptorList: *mut MdlRaw);

    fn MmMapIoSpace(
        PhysicalAddress: PhysicalAddress,
        NumberOfBytes: usize,
        CacheType: u32,
    ) -> *mut c_void;

    fn MmUnmapIoSpace(BaseAddress: *mut c_void, NumberOfBytes: usize);

    fn MmGetPhysicalAddress(BaseAddress: *const c_void) -> PhysicalAddress;

    fn MmIsAddressValid(VirtualAddress: *const c_void) -> u8;

    fn MmCopyMemory(
        TargetAddress: *mut c_void,
        SourceAddress: MmCopyAddress,
        NumberOfBytes: usize,
        Flags: u32,
        NumberOfBytesTransferred: *mut usize,
    ) -> NtStatus;

    fn ZwAllocateVirtualMemory(
        ProcessHandle: *mut c_void,
        BaseAddress: *mut *mut c_void,
        ZeroBits: usize,
        RegionSize: *mut usize,
        AllocationType: u32,
        Protect: u32,
    ) -> NtStatus;

    fn ZwFreeVirtualMemory(
        ProcessHandle: *mut c_void,
        BaseAddress: *mut *mut c_void,
        RegionSize: *mut usize,
        FreeType: u32,
    ) -> NtStatus;

    fn ZwProtectVirtualMemory(
        ProcessHandle: *mut c_void,
        BaseAddress: *mut *mut c_void,
        RegionSize: *mut usize,
        NewProtect: u32,
        OldProtect: *mut u32,
    ) -> NtStatus;
}