// wraith/km_client/driver.rs

//! Driver handle and connection management.

use std::ffi::CString;
use std::io;
use std::mem::MaybeUninit;
use std::ptr;

use super::ioctl::IoctlCode;
use super::process::ProcessOps;
use super::{ClientError, ClientResult};

/// Owned Win32 handle to an opened driver device.
///
/// Created by [`DriverHandle::open`]; the handle is closed when the value is dropped.
#[derive(Debug)]
pub struct DriverHandle {
    /// Raw handle returned by `CreateFileA`.
    handle: *mut std::ffi::c_void,
}

17impl DriverHandle {
18    /// open driver by symbolic link name
19    pub fn open(name: &str) -> ClientResult<Self> {
20        let name = CString::new(name).map_err(|_| ClientError::DriverNotFound)?;
21
22        let handle = unsafe {
23            CreateFileA(
24                name.as_ptr(),
25                GENERIC_READ | GENERIC_WRITE,
26                0,
27                ptr::null_mut(),
28                OPEN_EXISTING,
29                FILE_ATTRIBUTE_NORMAL,
30                ptr::null_mut(),
31            )
32        };
33
34        if handle == INVALID_HANDLE_VALUE {
35            return Err(ClientError::DriverOpenFailed(io::Error::last_os_error()));
36        }
37
38        Ok(Self { handle })
39    }
40
41    /// get raw handle
42    pub fn as_raw(&self) -> *mut std::ffi::c_void {
43        self.handle
44    }
45
46    /// send IOCTL with input and output buffers
47    pub fn ioctl<I, O>(
48        &self,
49        code: u32,
50        input: Option<&I>,
51        mut output: Option<&mut O>,
52    ) -> ClientResult<u32>
53    where
54        I: ?Sized,
55        O: ?Sized,
56    {
57        let input_ptr = input.map(|i| i as *const I as *const u8).unwrap_or(ptr::null());
58        let input_size = input.map(|i| std::mem::size_of_val(i) as u32).unwrap_or(0);
59
60        let output_size = output.as_ref().map(|o| std::mem::size_of_val(*o) as u32).unwrap_or(0);
61        let output_ptr = output
62            .as_mut()
63            .map(|o| *o as *mut O as *mut u8)
64            .unwrap_or(ptr::null_mut());
65
66        let mut bytes_returned = 0u32;
67
68        let result = unsafe {
69            DeviceIoControl(
70                self.handle,
71                code,
72                input_ptr as *const _,
73                input_size,
74                output_ptr as *mut _,
75                output_size,
76                &mut bytes_returned,
77                ptr::null_mut(),
78            )
79        };
80
81        if result == 0 {
82            return Err(ClientError::IoctlFailed {
83                code,
84                error: io::Error::last_os_error(),
85            });
86        }
87
88        Ok(bytes_returned)
89    }
90
91    /// send IOCTL with byte buffers
92    pub fn ioctl_raw(
93        &self,
94        code: u32,
95        input: &[u8],
96        output: &mut [u8],
97    ) -> ClientResult<u32> {
98        let mut bytes_returned = 0u32;
99
100        let result = unsafe {
101            DeviceIoControl(
102                self.handle,
103                code,
104                if input.is_empty() { ptr::null() } else { input.as_ptr() as *const _ },
105                input.len() as u32,
106                if output.is_empty() { ptr::null_mut() } else { output.as_mut_ptr() as *mut _ },
107                output.len() as u32,
108                &mut bytes_returned,
109                ptr::null_mut(),
110            )
111        };
112
113        if result == 0 {
114            return Err(ClientError::IoctlFailed {
115                code,
116                error: io::Error::last_os_error(),
117            });
118        }
119
120        Ok(bytes_returned)
121    }
122}
123
124impl Drop for DriverHandle {
125    fn drop(&mut self) {
126        if self.handle != INVALID_HANDLE_VALUE {
127            unsafe { CloseHandle(self.handle) };
128        }
129    }
130}
131
132// SAFETY: handle can be sent between threads
133unsafe impl Send for DriverHandle {}
134unsafe impl Sync for DriverHandle {}
135
136/// high-level driver client
137pub struct DriverClient {
138    handle: DriverHandle,
139}
140
141impl DriverClient {
142    /// connect to driver
143    pub fn connect(device_name: &str) -> ClientResult<Self> {
144        let handle = DriverHandle::open(device_name)?;
145        Ok(Self { handle })
146    }
147
148    /// get underlying handle
149    pub fn handle(&self) -> &DriverHandle {
150        &self.handle
151    }
152
153    /// open process for memory operations
154    pub fn open_process(&self, pid: u32) -> ClientResult<ProcessOps> {
155        ProcessOps::new(&self.handle, pid)
156    }
157
158    /// read value from remote process
159    pub fn read_process_memory<T: Copy>(&self, pid: u32, address: u64) -> ClientResult<T> {
160        use super::ioctl::{ReadMemoryRequest, codes};
161
162        let request = ReadMemoryRequest {
163            process_id: pid,
164            address,
165            size: std::mem::size_of::<T>() as u32,
166        };
167
168        // separate input and output buffers
169        let input_bytes = unsafe {
170            std::slice::from_raw_parts(
171                &request as *const _ as *const u8,
172                std::mem::size_of::<ReadMemoryRequest>(),
173            )
174        };
175
176        let mut output_buffer = vec![0u8; std::mem::size_of::<T>()];
177
178        let bytes = self.handle.ioctl_raw(
179            codes::READ_MEMORY.code(),
180            input_bytes,
181            &mut output_buffer,
182        )?;
183
184        if bytes as usize != std::mem::size_of::<T>() {
185            return Err(ClientError::InvalidResponse {
186                expected: std::mem::size_of::<T>(),
187                received: bytes as usize,
188            });
189        }
190
191        // read value from buffer
192        Ok(unsafe { ptr::read(output_buffer.as_ptr() as *const T) })
193    }
194
195    /// write value to remote process
196    pub fn write_process_memory<T: Copy>(&self, pid: u32, address: u64, value: &T) -> ClientResult<()> {
197        use super::ioctl::{WriteMemoryRequest, codes};
198
199        // create buffer with header + data
200        let header_size = std::mem::size_of::<WriteMemoryRequest>();
201        let data_size = std::mem::size_of::<T>();
202        let mut buffer = vec![0u8; header_size + data_size];
203
204        let request = WriteMemoryRequest {
205            process_id: pid,
206            address,
207            size: data_size as u32,
208        };
209
210        // copy header
211        unsafe {
212            ptr::copy_nonoverlapping(
213                &request as *const _ as *const u8,
214                buffer.as_mut_ptr(),
215                header_size,
216            );
217            // copy data
218            ptr::copy_nonoverlapping(
219                value as *const T as *const u8,
220                buffer.as_mut_ptr().add(header_size),
221                data_size,
222            );
223        }
224
225        let _ = self.handle.ioctl_raw(codes::WRITE_MEMORY.code(), &buffer, &mut [])?;
226        Ok(())
227    }
228
229    /// get module base address
230    pub fn get_module_base(&self, pid: u32, module_name: &str) -> ClientResult<u64> {
231        use super::ioctl::{GetModuleBaseRequest, GetModuleBaseResponse, codes};
232
233        let name_bytes = module_name.as_bytes();
234        let header_size = std::mem::size_of::<GetModuleBaseRequest>();
235        let total_size = header_size + name_bytes.len();
236
237        let mut input = vec![0u8; total_size];
238        let request = GetModuleBaseRequest {
239            process_id: pid,
240            module_name_offset: header_size as u32,
241            module_name_length: name_bytes.len() as u32,
242        };
243
244        unsafe {
245            ptr::copy_nonoverlapping(
246                &request as *const _ as *const u8,
247                input.as_mut_ptr(),
248                header_size,
249            );
250            ptr::copy_nonoverlapping(
251                name_bytes.as_ptr(),
252                input.as_mut_ptr().add(header_size),
253                name_bytes.len(),
254            );
255        }
256
257        let mut response = MaybeUninit::<GetModuleBaseResponse>::uninit();
258        let bytes = self.handle.ioctl(
259            codes::GET_MODULE_BASE.code(),
260            Some(&input[..]),
261            Some(unsafe { response.assume_init_mut() }),
262        )?;
263
264        if bytes as usize != std::mem::size_of::<GetModuleBaseResponse>() {
265            return Err(ClientError::InvalidResponse {
266                expected: std::mem::size_of::<GetModuleBaseResponse>(),
267                received: bytes as usize,
268            });
269        }
270
271        Ok(unsafe { response.assume_init() }.base_address)
272    }
273
274    /// allocate memory in remote process
275    pub fn allocate_memory(
276        &self,
277        pid: u32,
278        size: u64,
279        protection: u32,
280    ) -> ClientResult<u64> {
281        use super::ioctl::{AllocateMemoryRequest, AllocateMemoryResponse, codes};
282
283        let request = AllocateMemoryRequest {
284            process_id: pid,
285            size,
286            protection,
287            preferred_address: 0,
288        };
289
290        let mut response = MaybeUninit::<AllocateMemoryResponse>::uninit();
291        let bytes = self.handle.ioctl(
292            codes::ALLOCATE_MEMORY.code(),
293            Some(&request),
294            Some(unsafe { response.assume_init_mut() }),
295        )?;
296
297        if bytes as usize != std::mem::size_of::<AllocateMemoryResponse>() {
298            return Err(ClientError::InvalidResponse {
299                expected: std::mem::size_of::<AllocateMemoryResponse>(),
300                received: bytes as usize,
301            });
302        }
303
304        Ok(unsafe { response.assume_init() }.allocated_address)
305    }
306
307    /// free memory in remote process
308    pub fn free_memory(&self, pid: u32, address: u64) -> ClientResult<()> {
309        use super::ioctl::{FreeMemoryRequest, codes};
310
311        let request = FreeMemoryRequest {
312            process_id: pid,
313            address,
314        };
315
316        let _ = self.handle.ioctl(codes::FREE_MEMORY.code(), Some(&request), None::<&mut ()>)?;
317        Ok(())
318    }
319
320    /// change memory protection
321    pub fn protect_memory(
322        &self,
323        pid: u32,
324        address: u64,
325        size: u64,
326        protection: u32,
327    ) -> ClientResult<u32> {
328        use super::ioctl::{ProtectMemoryRequest, ProtectMemoryResponse, codes};
329
330        let request = ProtectMemoryRequest {
331            process_id: pid,
332            address,
333            size,
334            new_protection: protection,
335        };
336
337        let mut response = MaybeUninit::<ProtectMemoryResponse>::uninit();
338        let bytes = self.handle.ioctl(
339            codes::PROTECT_MEMORY.code(),
340            Some(&request),
341            Some(unsafe { response.assume_init_mut() }),
342        )?;
343
344        if bytes as usize != std::mem::size_of::<ProtectMemoryResponse>() {
345            return Err(ClientError::InvalidResponse {
346                expected: std::mem::size_of::<ProtectMemoryResponse>(),
347                received: bytes as usize,
348            });
349        }
350
351        Ok(unsafe { response.assume_init() }.old_protection)
352    }
353}
354
// Win32 constants used with `CreateFileA` (values as defined in the Windows SDK headers).

/// `GENERIC_READ` access right.
const GENERIC_READ: u32 = 0x8000_0000;
/// `GENERIC_WRITE` access right.
const GENERIC_WRITE: u32 = 0x4000_0000;
/// `OPEN_EXISTING` creation disposition: fail if the target does not already exist.
const OPEN_EXISTING: u32 = 3;
/// `FILE_ATTRIBUTE_NORMAL`: no other attributes or flags (in particular, no overlapped I/O).
const FILE_ATTRIBUTE_NORMAL: u32 = 0x80;
/// `INVALID_HANDLE_VALUE`, i.e. `(HANDLE)-1`, which `CreateFileA` returns on failure.
const INVALID_HANDLE_VALUE: *mut std::ffi::c_void = -1isize as *mut _;

// Win32 functions imported from kernel32.dll.
#[link(name = "kernel32")]
extern "system" {
    /// Opens a file or device; returns `INVALID_HANDLE_VALUE` on failure.
    ///
    /// `lpFileName` is declared as `c_char` so it matches `CString::as_ptr` on every target.
    fn CreateFileA(
        lpFileName: *const std::os::raw::c_char,
        dwDesiredAccess: u32,
        dwShareMode: u32,
        lpSecurityAttributes: *mut std::ffi::c_void,
        dwCreationDisposition: u32,
        dwFlagsAndAttributes: u32,
        hTemplateFile: *mut std::ffi::c_void,
    ) -> *mut std::ffi::c_void;

    /// Closes a handle; returns nonzero on success.
    fn CloseHandle(hObject: *mut std::ffi::c_void) -> i32;

    /// Sends a control code to a device driver; returns nonzero on success and stores the
    /// number of output bytes written in `*lpBytesReturned`.
    fn DeviceIoControl(
        hDevice: *mut std::ffi::c_void,
        dwIoControlCode: u32,
        lpInBuffer: *const std::ffi::c_void,
        nInBufferSize: u32,
        lpOutBuffer: *mut std::ffi::c_void,
        nOutBufferSize: u32,
        lpBytesReturned: *mut u32,
        lpOverlapped: *mut std::ffi::c_void,
    ) -> i32;
}