wraith/manipulation/remote/
thread.rs

1//! Remote thread creation and manipulation
2
3use super::process::RemoteProcess;
4use crate::error::{Result, WraithError};
5use crate::manipulation::syscall::{
6    get_syscall_table, nt_close, nt_success, DirectSyscall, ObjectAttributes,
7};
8
9/// thread creation flags
10#[derive(Debug, Clone, Copy, Default)]
11pub struct ThreadCreationFlags {
12    pub flags: u32,
13}
14
15impl ThreadCreationFlags {
16    /// create thread in running state
17    pub const fn running() -> Self {
18        Self { flags: 0 }
19    }
20
21    /// create thread in suspended state
22    pub const fn suspended() -> Self {
23        Self { flags: CREATE_SUSPENDED }
24    }
25
26    /// skip thread attach notifications (dangerous)
27    pub const fn skip_attach() -> Self {
28        Self { flags: THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH }
29    }
30
31    /// hide thread from debugger
32    pub const fn hide_from_debugger() -> Self {
33        Self { flags: THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER }
34    }
35
36    /// combine flags
37    pub const fn with(self, other: Self) -> Self {
38        Self { flags: self.flags | other.flags }
39    }
40}
41
42/// options for remote thread creation
43#[derive(Debug, Clone, Copy)]
44pub struct RemoteThreadOptions {
45    pub flags: ThreadCreationFlags,
46    pub stack_size: usize,
47    pub create_suspended: bool,
48}
49
50impl Default for RemoteThreadOptions {
51    fn default() -> Self {
52        Self {
53            flags: ThreadCreationFlags::running(),
54            stack_size: 0, // default stack size
55            create_suspended: false,
56        }
57    }
58}
59
60impl RemoteThreadOptions {
61    pub fn suspended() -> Self {
62        Self {
63            flags: ThreadCreationFlags::suspended(),
64            stack_size: 0,
65            create_suspended: true,
66        }
67    }
68
69    pub fn with_stack_size(mut self, size: usize) -> Self {
70        self.stack_size = size;
71        self
72    }
73}
74
/// Wrapper around a thread handle, as returned by [`create_remote_thread`] and
/// [`create_remote_thread_win32`].
#[derive(Debug)]
pub struct RemoteThread {
    /// Raw thread handle value.
    handle: usize,
    /// ID of the thread behind `handle`.
    id: u32,
    /// Whether this wrapper closes `handle` when it is dropped.
    owns_handle: bool,
}
81
82impl RemoteThread {
83    /// get the thread handle
84    pub fn handle(&self) -> usize {
85        self.handle
86    }
87
88    /// get the thread ID
89    pub fn id(&self) -> u32 {
90        self.id
91    }
92
93    /// wait for thread to complete
94    pub fn wait(&self, timeout_ms: u32) -> Result<()> {
95        let result = unsafe { WaitForSingleObject(self.handle, timeout_ms) };
96        if result == WAIT_OBJECT_0 {
97            Ok(())
98        } else {
99            Err(WraithError::RemoteThreadFailed {
100                reason: format!("wait failed with result {}", result),
101            })
102        }
103    }
104
105    /// wait indefinitely for thread to complete
106    pub fn wait_infinite(&self) -> Result<()> {
107        self.wait(INFINITE)
108    }
109
110    /// get exit code (returns None if thread is still running)
111    pub fn exit_code(&self) -> Result<Option<u32>> {
112        let mut exit_code: u32 = 0;
113        let result = unsafe { GetExitCodeThread(self.handle, &mut exit_code) };
114        if result == 0 {
115            return Err(WraithError::RemoteThreadFailed {
116                reason: "GetExitCodeThread failed".into(),
117            });
118        }
119
120        if exit_code == STILL_ACTIVE {
121            Ok(None)
122        } else {
123            Ok(Some(exit_code))
124        }
125    }
126
127    /// suspend the thread
128    pub fn suspend(&self) -> Result<u32> {
129        let result = unsafe { SuspendThread(self.handle) };
130        if result == u32::MAX {
131            Err(WraithError::ThreadSuspendResumeFailed {
132                reason: "SuspendThread failed".into(),
133            })
134        } else {
135            Ok(result)
136        }
137    }
138
139    /// resume the thread
140    pub fn resume(&self) -> Result<u32> {
141        let result = unsafe { ResumeThread(self.handle) };
142        if result == u32::MAX {
143            Err(WraithError::ThreadSuspendResumeFailed {
144                reason: "ResumeThread failed".into(),
145            })
146        } else {
147            Ok(result)
148        }
149    }
150
151    /// terminate the thread
152    pub fn terminate(&self, exit_code: u32) -> Result<()> {
153        let table = get_syscall_table()?;
154        let syscall = DirectSyscall::from_table(table, "NtTerminateThread")?;
155
156        // SAFETY: terminating a thread with valid handle
157        let status = unsafe { syscall.call2(self.handle, exit_code as usize) };
158
159        if nt_success(status) {
160            Ok(())
161        } else {
162            Err(WraithError::SyscallFailed {
163                name: "NtTerminateThread".into(),
164                status,
165            })
166        }
167    }
168
169    /// leak the handle (don't close on drop)
170    pub fn leak(mut self) -> usize {
171        self.owns_handle = false;
172        self.handle
173    }
174}
175
176impl Drop for RemoteThread {
177    fn drop(&mut self) {
178        if self.owns_handle && self.handle != 0 {
179            let _ = nt_close(self.handle);
180        }
181    }
182}
183
184// SAFETY: thread handle can be sent between threads
185unsafe impl Send for RemoteThread {}
186unsafe impl Sync for RemoteThread {}
187
188/// create a remote thread in the target process
189pub fn create_remote_thread(
190    process: &RemoteProcess,
191    start_address: usize,
192    parameter: usize,
193    options: RemoteThreadOptions,
194) -> Result<RemoteThread> {
195    // use NtCreateThreadEx for more control
196    let table = get_syscall_table()?;
197    let syscall = DirectSyscall::from_table(table, "NtCreateThreadEx")?;
198
199    let mut thread_handle: usize = 0;
200    let obj_attr = ObjectAttributes::new();
201
202    let create_flags = if options.create_suspended {
203        options.flags.flags | CREATE_SUSPENDED
204    } else {
205        options.flags.flags
206    };
207
208    // SAFETY: all pointers point to valid stack data
209    let status = unsafe {
210        syscall.call_many(&[
211            &mut thread_handle as *mut usize as usize, // ThreadHandle
212            THREAD_ALL_ACCESS as usize,                 // DesiredAccess
213            &obj_attr as *const _ as usize,             // ObjectAttributes
214            process.handle(),                           // ProcessHandle
215            start_address,                              // StartRoutine
216            parameter,                                  // Argument
217            create_flags as usize,                      // CreateFlags
218            0,                                          // ZeroBits
219            options.stack_size,                         // StackSize
220            0,                                          // MaximumStackSize
221            0,                                          // AttributeList
222        ])
223    };
224
225    if nt_success(status) {
226        // get thread ID
227        let tid = get_thread_id(thread_handle)?;
228        Ok(RemoteThread {
229            handle: thread_handle,
230            id: tid,
231            owns_handle: true,
232        })
233    } else {
234        Err(WraithError::RemoteThreadFailed {
235            reason: format!("NtCreateThreadEx failed with status {:#x}", status as u32),
236        })
237    }
238}
239
240/// create remote thread using Win32 API (simpler but more detectable)
241pub fn create_remote_thread_win32(
242    process: &RemoteProcess,
243    start_address: usize,
244    parameter: usize,
245    suspended: bool,
246) -> Result<RemoteThread> {
247    let mut thread_id: u32 = 0;
248    let flags = if suspended { CREATE_SUSPENDED } else { 0 };
249
250    let handle = unsafe {
251        CreateRemoteThread(
252            process.handle(),
253            core::ptr::null(),
254            0,
255            start_address,
256            parameter,
257            flags,
258            &mut thread_id,
259        )
260    };
261
262    if handle == 0 {
263        return Err(WraithError::RemoteThreadFailed {
264            reason: format!("CreateRemoteThread failed: {}", unsafe { GetLastError() }),
265        });
266    }
267
268    Ok(RemoteThread {
269        handle,
270        id: thread_id,
271        owns_handle: true,
272    })
273}
274
275fn get_thread_id(thread_handle: usize) -> Result<u32> {
276    let table = get_syscall_table()?;
277    let syscall = DirectSyscall::from_table(table, "NtQueryInformationThread")?;
278
279    #[repr(C)]
280    struct ThreadBasicInfo {
281        exit_status: i32,
282        teb_base: usize,
283        client_id: ClientId,
284        affinity_mask: usize,
285        priority: i32,
286        base_priority: i32,
287    }
288
289    #[repr(C)]
290    struct ClientId {
291        unique_process: usize,
292        unique_thread: usize,
293    }
294
295    let mut info = core::mem::MaybeUninit::<ThreadBasicInfo>::uninit();
296    let mut return_length: u32 = 0;
297
298    // SAFETY: buffer is correctly sized
299    let status = unsafe {
300        syscall.call5(
301            thread_handle,
302            0, // ThreadBasicInformation
303            info.as_mut_ptr() as usize,
304            core::mem::size_of::<ThreadBasicInfo>(),
305            &mut return_length as *mut u32 as usize,
306        )
307    };
308
309    if nt_success(status) {
310        let info = unsafe { info.assume_init() };
311        Ok(info.client_id.unique_thread as u32)
312    } else {
313        // fallback to GetThreadId if syscall fails
314        let tid = unsafe { GetThreadId(thread_handle) };
315        if tid != 0 {
316            Ok(tid)
317        } else {
318            Err(WraithError::RemoteThreadFailed {
319                reason: "failed to get thread ID".into(),
320            })
321        }
322    }
323}
324
// Access mask requested for threads created by `create_remote_thread`.
// NOTE(review): 0x001F_03FF looks like the pre-Vista THREAD_ALL_ACCESS value; newer
// SDK headers use 0x001F_FFFF — confirm which is intended.
const THREAD_ALL_ACCESS: u32 = 0x001F_03FF;

// Win32 `CreateRemoteThread` creation flag (dwCreationFlags).
const CREATE_SUSPENDED: u32 = 0x0000_0004;
// Native `NtCreateThreadEx` creation flags. HIDE_FROM_DEBUGGER shares the value 0x4
// with the Win32 CREATE_SUSPENDED above, so the two flag sets must not be mixed.
const THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH: u32 = 0x0000_0002;
const THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER: u32 = 0x0000_0004;

// Values used with WaitForSingleObject and GetExitCodeThread.
const WAIT_OBJECT_0: u32 = 0;
const INFINITE: u32 = u32::MAX;
const STILL_ACTIVE: u32 = 0x103;
337
// kernel32 imports used by this module. Handles are passed as raw `usize` values
// and BOOL results as `i32`.
#[link(name = "kernel32")]
extern "system" {
    // Creates a thread in another process. `lpStartAddress` is a raw address in the
    // target's address space. `create_remote_thread_win32` treats a returned handle
    // of 0 as failure and then reads `GetLastError`.
    fn CreateRemoteThread(
        hProcess: usize,
        lpThreadAttributes: *const core::ffi::c_void,
        dwStackSize: usize,
        lpStartAddress: usize,
        lpParameter: usize,
        dwCreationFlags: u32,
        lpThreadId: *mut u32,
    ) -> usize;

    // Waits on a handle; callers compare the result against WAIT_OBJECT_0.
    fn WaitForSingleObject(hHandle: usize, dwMilliseconds: u32) -> u32;
    // Writes the exit code (or STILL_ACTIVE) to `lpExitCode`; callers treat 0 as failure.
    fn GetExitCodeThread(hThread: usize, lpExitCode: *mut u32) -> i32;
    // Both are treated as failed when they return u32::MAX ((DWORD)-1).
    fn SuspendThread(hThread: usize) -> u32;
    fn ResumeThread(hThread: usize) -> u32;
    // Callers treat a returned ID of 0 as failure.
    fn GetThreadId(Thread: usize) -> u32;
    fn GetLastError() -> u32;
}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_creation_flags() {
        assert_eq!(ThreadCreationFlags::running().flags, 0);

        let flags = ThreadCreationFlags::suspended();
        assert_eq!(flags.flags, CREATE_SUSPENDED);

        // combine two flags with distinct bit values so the check can detect either one
        // being dropped (CREATE_SUSPENDED and HIDE_FROM_DEBUGGER are both 0x4, so
        // combining those two could never fail)
        let combined =
            ThreadCreationFlags::skip_attach().with(ThreadCreationFlags::hide_from_debugger());
        assert_eq!(
            combined.flags,
            THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH | THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER
        );
        assert_ne!(combined.flags & THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH, 0);
        assert_ne!(combined.flags & THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER, 0);
    }

    #[test]
    fn test_remote_thread_options() {
        let defaults = RemoteThreadOptions::default();
        assert_eq!(defaults.flags.flags, 0);
        assert_eq!(defaults.stack_size, 0);
        assert!(!defaults.create_suspended);

        let suspended = RemoteThreadOptions::suspended().with_stack_size(0x1_0000);
        assert!(suspended.create_suspended);
        assert_eq!(suspended.stack_size, 0x1_0000);
    }
}