wraith/navigation/
thread_iter.rs1#[cfg(all(not(feature = "std"), feature = "alloc"))]
4use alloc::vec::Vec;
5
6#[cfg(feature = "std")]
7use std::vec::Vec;
8
9use crate::arch::segment;
10use crate::error::{Result, WraithError};
11
/// Identity and stack bounds of a single thread, as read from its TEB.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadInfo {
    /// Thread identifier.
    pub thread_id: u32,
    /// Address of the thread's TEB.
    pub teb_address: usize,
    /// Upper (exclusive) bound of the thread's stack range.
    pub stack_base: usize,
    /// Lower (inclusive) bound of the thread's stack range.
    pub stack_limit: usize,
}
20
21impl ThreadInfo {
22 pub fn current() -> Result<Self> {
24 let teb = unsafe { segment::get_teb() };
26 if teb.is_null() {
27 return Err(WraithError::InvalidTebAccess);
28 }
29
30 #[cfg(target_arch = "x86_64")]
31 let (tid, stack_base, stack_limit) = {
32 let tid = unsafe { segment::get_current_tid() };
34 let stack_base = unsafe { *(teb.add(0x08) as *const u64) } as usize;
35 let stack_limit = unsafe { *(teb.add(0x10) as *const u64) } as usize;
36 (tid, stack_base, stack_limit)
37 };
38
39 #[cfg(target_arch = "x86")]
40 let (tid, stack_base, stack_limit) = {
41 let tid = unsafe { segment::get_current_tid() };
43 let stack_base = unsafe { *(teb.add(0x04) as *const u32) } as usize;
44 let stack_limit = unsafe { *(teb.add(0x08) as *const u32) } as usize;
45 (tid, stack_base, stack_limit)
46 };
47
48 Ok(Self {
49 thread_id: tid,
50 teb_address: teb as usize,
51 stack_base,
52 stack_limit,
53 })
54 }
55
56 pub fn is_on_stack(&self, address: usize) -> bool {
58 address >= self.stack_limit && address < self.stack_base
60 }
61}
62
/// Iterator over the threads of one process, backed by a Toolhelp32 snapshot.
///
/// The iterator owns the snapshot handle; it is released when the iterator is
/// dropped.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ThreadIterator {
    /// Toolhelp snapshot handle.
    snapshot: *mut core::ffi::c_void,
    /// `true` until the first `next` call, which must use `Thread32First`.
    first: bool,
    /// Process whose threads are yielded; entries of other processes are skipped.
    target_pid: u32,
}
69
70impl ThreadIterator {
71 pub fn new() -> Result<Self> {
73 let current_pid = unsafe { segment::get_current_pid() };
75 Self::for_process(current_pid)
76 }
77
78 pub fn for_process(pid: u32) -> Result<Self> {
80 let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) };
82
83 if snapshot == INVALID_HANDLE_VALUE {
84 return Err(WraithError::from_last_error("CreateToolhelp32Snapshot"));
85 }
86
87 Ok(Self {
88 snapshot,
89 first: true,
90 target_pid: pid,
91 })
92 }
93}
94
95impl Iterator for ThreadIterator {
96 type Item = ThreadEntry;
97
98 fn next(&mut self) -> Option<Self::Item> {
99 let mut entry = ThreadEntry32 {
100 size: core::mem::size_of::<ThreadEntry32>() as u32,
101 ..Default::default()
102 };
103
104 loop {
105 let success = if self.first {
107 self.first = false;
108 unsafe { Thread32First(self.snapshot, &mut entry) }
109 } else {
110 unsafe { Thread32Next(self.snapshot, &mut entry) }
111 };
112
113 if success == 0 {
114 return None;
115 }
116
117 if entry.owner_process_id == self.target_pid {
119 return Some(ThreadEntry {
120 thread_id: entry.thread_id,
121 owner_process_id: entry.owner_process_id,
122 base_priority: entry.base_priority,
123 });
124 }
125 }
126 }
127}
128
129impl Drop for ThreadIterator {
130 fn drop(&mut self) {
131 if self.snapshot != INVALID_HANDLE_VALUE {
132 unsafe {
134 CloseHandle(self.snapshot);
135 }
136 }
137 }
138}
139
/// One thread record yielded by [`ThreadIterator`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadEntry {
    /// Thread identifier.
    pub thread_id: u32,
    /// Identifier of the process that owns the thread.
    pub owner_process_id: u32,
    /// Base priority reported for the thread by the snapshot.
    pub base_priority: i32,
}
147
/// Rust mirror of the Win32 `THREADENTRY32` record filled in by
/// `Thread32First` / `Thread32Next`.
///
/// Field order and widths must match the C definition exactly. The size
/// assertion after this struct turns an accidental layout change into a
/// compile error instead of memory corruption at the FFI boundary.
#[repr(C)]
#[derive(Default)]
// Some fields (e.g. `usage`, `delta_priority`, `flags`) are only written by the
// OS and never read from Rust. They must still exist to preserve the C layout.
#[allow(dead_code)]
struct ThreadEntry32 {
    /// `dwSize`: set to `size_of::<ThreadEntry32>()` before each Toolhelp call.
    size: u32,
    /// `cntUsage`.
    usage: u32,
    /// `th32ThreadID`.
    thread_id: u32,
    /// `th32OwnerProcessID`.
    owner_process_id: u32,
    /// `tpBasePri`.
    base_priority: i32,
    /// `tpDeltaPri`.
    delta_priority: i32,
    /// `dwFlags`.
    flags: u32,
}

// THREADENTRY32 consists of seven 32-bit fields: 28 bytes on every target.
const _: () = assert!(core::mem::size_of::<ThreadEntry32>() == 28);
160
/// `CreateToolhelp32Snapshot` flag selecting thread records.
const TH32CS_SNAPTHREAD: u32 = 0x4;
/// Win32 failure sentinel for handles: every bit set (numerically `-1`).
const INVALID_HANDLE_VALUE: *mut core::ffi::c_void = usize::MAX as *mut core::ffi::c_void;
163
164#[link(name = "kernel32")]
165extern "system" {
166 fn CreateToolhelp32Snapshot(flags: u32, process_id: u32) -> *mut core::ffi::c_void;
167 fn Thread32First(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
168 fn Thread32Next(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
169 fn CloseHandle(handle: *mut core::ffi::c_void) -> i32;
170}
171
172pub fn get_thread_ids() -> Result<Vec<u32>> {
174 Ok(ThreadIterator::new()?.map(|t| t.thread_id).collect())
175}
176
177pub fn thread_count() -> Result<usize> {
179 Ok(ThreadIterator::new()?.count())
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_current_thread() {
        let info = ThreadInfo::current().expect("should get thread info");
        assert!(info.thread_id > 0);
        assert!(info.teb_address != 0);
        assert!(info.stack_base > info.stack_limit);

        // A local in this very frame must lie inside the reported stack range,
        // and the (exclusive) upper bound itself must not.
        let local = 0u8;
        assert!(info.is_on_stack(&local as *const u8 as usize));
        assert!(!info.is_on_stack(info.stack_base));
    }

    #[test]
    fn test_thread_iterator() {
        let threads: Vec<_> = ThreadIterator::new()
            .expect("should create iterator")
            .collect();

        assert!(!threads.is_empty());

        // The iterator filters by process, so every entry shares one owner.
        let owner = threads[0].owner_process_id;
        assert!(threads.iter().all(|t| t.owner_process_id == owner));
    }

    #[test]
    fn test_get_thread_ids() {
        let ids = get_thread_ids().expect("should get thread ids");
        assert!(!ids.is_empty());

        let current = ThreadInfo::current().expect("should get current thread");
        assert!(ids.contains(&current.thread_id));
    }

    #[test]
    fn test_thread_count() {
        let count = thread_count().expect("should get thread count");
        assert!(count >= 1);
    }
}