wraith/manipulation/remote/thread.rs

1//! Remote thread creation and manipulation
2
3#[cfg(all(not(feature = "std"), feature = "alloc"))]
4use alloc::{format, string::String};
5
6#[cfg(feature = "std")]
7use std::{format, string::String};
8
9use super::process::RemoteProcess;
10use crate::error::{Result, WraithError};
11use crate::manipulation::syscall::{
12    get_syscall_table, nt_close, nt_success, DirectSyscall, ObjectAttributes,
13};
14
15/// thread creation flags
16#[derive(Debug, Clone, Copy, Default)]
17pub struct ThreadCreationFlags {
18    pub flags: u32,
19}
20
21impl ThreadCreationFlags {
22    /// create thread in running state
23    pub const fn running() -> Self {
24        Self { flags: 0 }
25    }
26
27    /// create thread in suspended state
28    pub const fn suspended() -> Self {
29        Self { flags: CREATE_SUSPENDED }
30    }
31
32    /// skip thread attach notifications (dangerous)
33    pub const fn skip_attach() -> Self {
34        Self { flags: THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH }
35    }
36
37    /// hide thread from debugger
38    pub const fn hide_from_debugger() -> Self {
39        Self { flags: THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER }
40    }
41
42    /// combine flags
43    pub const fn with(self, other: Self) -> Self {
44        Self { flags: self.flags | other.flags }
45    }
46}
47
48/// options for remote thread creation
49#[derive(Debug, Clone, Copy)]
50pub struct RemoteThreadOptions {
51    pub flags: ThreadCreationFlags,
52    pub stack_size: usize,
53    pub create_suspended: bool,
54}
55
56impl Default for RemoteThreadOptions {
57    fn default() -> Self {
58        Self {
59            flags: ThreadCreationFlags::running(),
60            stack_size: 0, // default stack size
61            create_suspended: false,
62        }
63    }
64}
65
66impl RemoteThreadOptions {
67    pub fn suspended() -> Self {
68        Self {
69            flags: ThreadCreationFlags::suspended(),
70            stack_size: 0,
71            create_suspended: true,
72        }
73    }
74
75    pub fn with_stack_size(mut self, size: usize) -> Self {
76        self.stack_size = size;
77        self
78    }
79}
80
81/// wrapper for a remote thread handle
82pub struct RemoteThread {
83    handle: usize,
84    id: u32,
85    owns_handle: bool,
86}
87
88impl RemoteThread {
89    /// get the thread handle
90    pub fn handle(&self) -> usize {
91        self.handle
92    }
93
94    /// get the thread ID
95    pub fn id(&self) -> u32 {
96        self.id
97    }
98
99    /// wait for thread to complete
100    pub fn wait(&self, timeout_ms: u32) -> Result<()> {
101        let result = unsafe { WaitForSingleObject(self.handle, timeout_ms) };
102        if result == WAIT_OBJECT_0 {
103            Ok(())
104        } else {
105            Err(WraithError::RemoteThreadFailed {
106                reason: format!("wait failed with result {}", result),
107            })
108        }
109    }
110
111    /// wait indefinitely for thread to complete
112    pub fn wait_infinite(&self) -> Result<()> {
113        self.wait(INFINITE)
114    }
115
116    /// get exit code (returns None if thread is still running)
117    pub fn exit_code(&self) -> Result<Option<u32>> {
118        let mut exit_code: u32 = 0;
119        let result = unsafe { GetExitCodeThread(self.handle, &mut exit_code) };
120        if result == 0 {
121            return Err(WraithError::RemoteThreadFailed {
122                reason: "GetExitCodeThread failed".into(),
123            });
124        }
125
126        if exit_code == STILL_ACTIVE {
127            Ok(None)
128        } else {
129            Ok(Some(exit_code))
130        }
131    }
132
133    /// suspend the thread
134    pub fn suspend(&self) -> Result<u32> {
135        let result = unsafe { SuspendThread(self.handle) };
136        if result == u32::MAX {
137            Err(WraithError::ThreadSuspendResumeFailed {
138                reason: "SuspendThread failed".into(),
139            })
140        } else {
141            Ok(result)
142        }
143    }
144
145    /// resume the thread
146    pub fn resume(&self) -> Result<u32> {
147        let result = unsafe { ResumeThread(self.handle) };
148        if result == u32::MAX {
149            Err(WraithError::ThreadSuspendResumeFailed {
150                reason: "ResumeThread failed".into(),
151            })
152        } else {
153            Ok(result)
154        }
155    }
156
157    /// terminate the thread
158    pub fn terminate(&self, exit_code: u32) -> Result<()> {
159        let table = get_syscall_table()?;
160        let syscall = DirectSyscall::from_table(table, "NtTerminateThread")?;
161
162        // SAFETY: terminating a thread with valid handle
163        let status = unsafe { syscall.call2(self.handle, exit_code as usize) };
164
165        if nt_success(status) {
166            Ok(())
167        } else {
168            Err(WraithError::SyscallFailed {
169                name: "NtTerminateThread".into(),
170                status,
171            })
172        }
173    }
174
175    /// leak the handle (don't close on drop)
176    pub fn leak(mut self) -> usize {
177        self.owns_handle = false;
178        self.handle
179    }
180}
181
182impl Drop for RemoteThread {
183    fn drop(&mut self) {
184        if self.owns_handle && self.handle != 0 {
185            let _ = nt_close(self.handle);
186        }
187    }
188}
189
190// SAFETY: thread handle can be sent between threads
191unsafe impl Send for RemoteThread {}
192unsafe impl Sync for RemoteThread {}
193
194/// create a remote thread in the target process
195pub fn create_remote_thread(
196    process: &RemoteProcess,
197    start_address: usize,
198    parameter: usize,
199    options: RemoteThreadOptions,
200) -> Result<RemoteThread> {
201    // use NtCreateThreadEx for more control
202    let table = get_syscall_table()?;
203    let syscall = DirectSyscall::from_table(table, "NtCreateThreadEx")?;
204
205    let mut thread_handle: usize = 0;
206    let obj_attr = ObjectAttributes::new();
207
208    let create_flags = if options.create_suspended {
209        options.flags.flags | CREATE_SUSPENDED
210    } else {
211        options.flags.flags
212    };
213
214    // SAFETY: all pointers point to valid stack data
215    let status = unsafe {
216        syscall.call_many(&[
217            &mut thread_handle as *mut usize as usize, // ThreadHandle
218            THREAD_ALL_ACCESS as usize,                 // DesiredAccess
219            &obj_attr as *const _ as usize,             // ObjectAttributes
220            process.handle(),                           // ProcessHandle
221            start_address,                              // StartRoutine
222            parameter,                                  // Argument
223            create_flags as usize,                      // CreateFlags
224            0,                                          // ZeroBits
225            options.stack_size,                         // StackSize
226            0,                                          // MaximumStackSize
227            0,                                          // AttributeList
228        ])
229    };
230
231    if nt_success(status) {
232        // get thread ID
233        let tid = get_thread_id(thread_handle)?;
234        Ok(RemoteThread {
235            handle: thread_handle,
236            id: tid,
237            owns_handle: true,
238        })
239    } else {
240        Err(WraithError::RemoteThreadFailed {
241            reason: format!("NtCreateThreadEx failed with status {:#x}", status as u32),
242        })
243    }
244}
245
246/// create remote thread using Win32 API (simpler but more detectable)
247pub fn create_remote_thread_win32(
248    process: &RemoteProcess,
249    start_address: usize,
250    parameter: usize,
251    suspended: bool,
252) -> Result<RemoteThread> {
253    let mut thread_id: u32 = 0;
254    let flags = if suspended { CREATE_SUSPENDED } else { 0 };
255
256    let handle = unsafe {
257        CreateRemoteThread(
258            process.handle(),
259            core::ptr::null(),
260            0,
261            start_address,
262            parameter,
263            flags,
264            &mut thread_id,
265        )
266    };
267
268    if handle == 0 {
269        return Err(WraithError::RemoteThreadFailed {
270            reason: format!("CreateRemoteThread failed: {}", unsafe { GetLastError() }),
271        });
272    }
273
274    Ok(RemoteThread {
275        handle,
276        id: thread_id,
277        owns_handle: true,
278    })
279}
280
281fn get_thread_id(thread_handle: usize) -> Result<u32> {
282    let table = get_syscall_table()?;
283    let syscall = DirectSyscall::from_table(table, "NtQueryInformationThread")?;
284
285    #[repr(C)]
286    struct ThreadBasicInfo {
287        exit_status: i32,
288        teb_base: usize,
289        client_id: ClientId,
290        affinity_mask: usize,
291        priority: i32,
292        base_priority: i32,
293    }
294
295    #[repr(C)]
296    struct ClientId {
297        unique_process: usize,
298        unique_thread: usize,
299    }
300
301    let mut info = core::mem::MaybeUninit::<ThreadBasicInfo>::uninit();
302    let mut return_length: u32 = 0;
303
304    // SAFETY: buffer is correctly sized
305    let status = unsafe {
306        syscall.call5(
307            thread_handle,
308            0, // ThreadBasicInformation
309            info.as_mut_ptr() as usize,
310            core::mem::size_of::<ThreadBasicInfo>(),
311            &mut return_length as *mut u32 as usize,
312        )
313    };
314
315    if nt_success(status) {
316        let info = unsafe { info.assume_init() };
317        Ok(info.client_id.unique_thread as u32)
318    } else {
319        // fallback to GetThreadId if syscall fails
320        let tid = unsafe { GetThreadId(thread_handle) };
321        if tid != 0 {
322            Ok(tid)
323        } else {
324            Err(WraithError::RemoteThreadFailed {
325                reason: "failed to get thread ID".into(),
326            })
327        }
328    }
329}
330
/// Access mask requested for threads created via `NtCreateThreadEx`.
/// It equals STANDARD_RIGHTS_REQUIRED (0xF0000) | SYNCHRONIZE (0x100000) | 0x3FF, which is the
/// legacy SDK definition of THREAD_ALL_ACCESS.
const THREAD_ALL_ACCESS: u32 = 0x001F_03FF;

/// Win32 `dwCreationFlags` bit (e.g. for `CreateRemoteThread`): start the thread suspended.
const CREATE_SUSPENDED: u32 = 0x0000_0004;
/// Creation flag: skip thread attach notifications for the new thread.
const THREAD_CREATE_FLAGS_SKIP_THREAD_ATTACH: u32 = 0x0000_0002;
/// Creation flag: hide the new thread from debuggers.
/// NOTE: this has the same bit value as `CREATE_SUSPENDED`.
const THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER: u32 = 0x0000_0004;

/// `WaitForSingleObject` result meaning the object was signaled.
const WAIT_OBJECT_0: u32 = 0;
/// `WaitForSingleObject` timeout meaning "wait forever".
const INFINITE: u32 = 0xFFFF_FFFF;
/// `GetExitCodeThread` value that is treated as "thread still running".
const STILL_ACTIVE: u32 = 259;
343
// kernel32 entry points used by this module
#[link(name = "kernel32")]
extern "system" {
    /// Win32 remote thread creation; this module treats a 0 return as failure
    fn CreateRemoteThread(
        process: usize,
        thread_attributes: *const core::ffi::c_void,
        stack_size: usize,
        start_address: usize,
        parameter: usize,
        creation_flags: u32,
        thread_id: *mut u32,
    ) -> usize;

    // suspend-count control
    fn SuspendThread(thread: usize) -> u32;
    fn ResumeThread(thread: usize) -> u32;

    // completion state and identity
    fn WaitForSingleObject(handle: usize, milliseconds: u32) -> u32;
    fn GetExitCodeThread(thread: usize, exit_code: *mut u32) -> i32;
    fn GetThreadId(thread: usize) -> u32;

    // last-error code of the calling thread
    fn GetLastError() -> u32;
}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// flag constructors yield the expected bits and compose through `with`
    #[test]
    fn test_thread_creation_flags() {
        let suspended_only = ThreadCreationFlags::suspended();
        assert_eq!(suspended_only.flags, CREATE_SUSPENDED);

        // NOTE(review): CREATE_SUSPENDED shares its value with
        // THREAD_CREATE_FLAGS_HIDE_FROM_DEBUGGER, so this check cannot tell the two apart
        let suspended_and_hidden =
            ThreadCreationFlags::hide_from_debugger().with(ThreadCreationFlags::suspended());
        assert_ne!(suspended_and_hidden.flags & CREATE_SUSPENDED, 0);
    }
}