wraith/navigation/
thread_iter.rs1use crate::arch::segment;
4use crate::error::{Result, WraithError};
5
/// Snapshot of identifying and stack-bound information for a single thread,
/// read from its Thread Environment Block (TEB).
///
/// The values are captured once at construction time and are not refreshed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThreadInfo {
    /// Thread identifier.
    pub thread_id: u32,
    /// Address of the thread's TEB.
    pub teb_address: usize,
    /// Upper (exclusive) bound of the stack range used by `is_on_stack`.
    pub stack_base: usize,
    /// Lower (inclusive) bound of the stack range used by `is_on_stack`.
    pub stack_limit: usize,
}
15impl ThreadInfo {
16 pub fn current() -> Result<Self> {
18 let teb = unsafe { segment::get_teb() };
20 if teb.is_null() {
21 return Err(WraithError::InvalidTebAccess);
22 }
23
24 #[cfg(target_arch = "x86_64")]
25 let (tid, stack_base, stack_limit) = {
26 let tid = unsafe { segment::get_current_tid() };
28 let stack_base = unsafe { *(teb.add(0x08) as *const u64) } as usize;
29 let stack_limit = unsafe { *(teb.add(0x10) as *const u64) } as usize;
30 (tid, stack_base, stack_limit)
31 };
32
33 #[cfg(target_arch = "x86")]
34 let (tid, stack_base, stack_limit) = {
35 let tid = unsafe { segment::get_current_tid() };
37 let stack_base = unsafe { *(teb.add(0x04) as *const u32) } as usize;
38 let stack_limit = unsafe { *(teb.add(0x08) as *const u32) } as usize;
39 (tid, stack_base, stack_limit)
40 };
41
42 Ok(Self {
43 thread_id: tid,
44 teb_address: teb as usize,
45 stack_base,
46 stack_limit,
47 })
48 }
49
50 pub fn is_on_stack(&self, address: usize) -> bool {
52 address >= self.stack_limit && address < self.stack_base
54 }
55}
56
/// Iterator over the threads owned by one process, backed by a Toolhelp32
/// thread snapshot. The snapshot handle is closed when the iterator is dropped.
pub struct ThreadIterator {
    /// Only threads whose owner process id equals this value are yielded.
    target_pid: u32,
    /// `true` until the first advance, which must go through `Thread32First`.
    first: bool,
    /// Snapshot handle returned by `CreateToolhelp32Snapshot`.
    snapshot: *mut core::ffi::c_void,
}
64impl ThreadIterator {
65 pub fn new() -> Result<Self> {
67 let current_pid = unsafe { segment::get_current_pid() };
69 Self::for_process(current_pid)
70 }
71
72 pub fn for_process(pid: u32) -> Result<Self> {
74 let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) };
76
77 if snapshot == INVALID_HANDLE_VALUE {
78 return Err(WraithError::from_last_error("CreateToolhelp32Snapshot"));
79 }
80
81 Ok(Self {
82 snapshot,
83 first: true,
84 target_pid: pid,
85 })
86 }
87}
88
89impl Iterator for ThreadIterator {
90 type Item = ThreadEntry;
91
92 fn next(&mut self) -> Option<Self::Item> {
93 let mut entry = ThreadEntry32 {
94 size: core::mem::size_of::<ThreadEntry32>() as u32,
95 ..Default::default()
96 };
97
98 loop {
99 let success = if self.first {
101 self.first = false;
102 unsafe { Thread32First(self.snapshot, &mut entry) }
103 } else {
104 unsafe { Thread32Next(self.snapshot, &mut entry) }
105 };
106
107 if success == 0 {
108 return None;
109 }
110
111 if entry.owner_process_id == self.target_pid {
113 return Some(ThreadEntry {
114 thread_id: entry.thread_id,
115 owner_process_id: entry.owner_process_id,
116 base_priority: entry.base_priority,
117 });
118 }
119 }
120 }
121}
122
123impl Drop for ThreadIterator {
124 fn drop(&mut self) {
125 if self.snapshot != INVALID_HANDLE_VALUE {
126 unsafe {
128 CloseHandle(self.snapshot);
129 }
130 }
131 }
132}
133
/// A thread reported by [`ThreadIterator`].
///
/// Plain data copied out of a Toolhelp32 `THREADENTRY32` record.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ThreadEntry {
    /// Thread identifier (`th32ThreadID`).
    pub thread_id: u32,
    /// Identifier of the process that owns the thread (`th32OwnerProcessID`).
    pub owner_process_id: u32,
    /// Base priority reported for the thread (`tpBasePri`).
    pub base_priority: i32,
}
/// Mirror of the Win32 `THREADENTRY32` structure filled in by
/// `Thread32First` / `Thread32Next`.
///
/// Field order and widths must match the C definition exactly, hence `#[repr(C)]`.
#[repr(C)]
#[derive(Default)]
struct ThreadEntry32 {
    /// `dwSize`: structure size in bytes; set to `size_of::<ThreadEntry32>()` before each call.
    size: u32,
    /// `cntUsage` (not read by this module).
    usage: u32,
    /// `th32ThreadID`.
    thread_id: u32,
    /// `th32OwnerProcessID`.
    owner_process_id: u32,
    /// `tpBasePri`.
    base_priority: i32,
    /// `tpDeltaPri` (not read by this module).
    delta_priority: i32,
    /// `dwFlags` (not read by this module).
    flags: u32,
}

// Compile-time layout guard: THREADENTRY32 is seven 32-bit fields (28 bytes)
// on both 32- and 64-bit Windows; a mismatch would corrupt every FFI call.
const _: () = assert!(core::mem::size_of::<ThreadEntry32>() == 28);
/// `CreateToolhelp32Snapshot` flag that includes all threads in the snapshot.
const TH32CS_SNAPTHREAD: u32 = 0x4;
/// Win32 invalid-handle sentinel (all bits set, i.e. `-1` as a handle value).
const INVALID_HANDLE_VALUE: *mut core::ffi::c_void = usize::MAX as *mut core::ffi::c_void;
158#[link(name = "kernel32")]
159extern "system" {
160 fn CreateToolhelp32Snapshot(flags: u32, process_id: u32) -> *mut core::ffi::c_void;
161 fn Thread32First(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
162 fn Thread32Next(snapshot: *mut core::ffi::c_void, entry: *mut ThreadEntry32) -> i32;
163 fn CloseHandle(handle: *mut core::ffi::c_void) -> i32;
164}
165
166pub fn get_thread_ids() -> Result<Vec<u32>> {
168 Ok(ThreadIterator::new()?.map(|t| t.thread_id).collect())
169}
170
171pub fn thread_count() -> Result<usize> {
173 Ok(ThreadIterator::new()?.count())
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_current_thread() {
        let info = ThreadInfo::current().expect("should get thread info");
        assert!(info.thread_id > 0);
        assert!(info.stack_base > info.stack_limit);

        // A local of this very function lives on the calling thread's stack.
        let local = 0u8;
        assert!(info.is_on_stack(&local as *const u8 as usize));
    }

    #[test]
    fn test_thread_iterator() {
        let threads: Vec<_> = ThreadIterator::new()
            .expect("should create iterator")
            .collect();

        assert!(!threads.is_empty());

        // The iterator filters by owner process, so every entry shares one owner.
        let owner = threads[0].owner_process_id;
        assert!(threads.iter().all(|t| t.owner_process_id == owner));
    }

    #[test]
    fn test_get_thread_ids() {
        let ids = get_thread_ids().expect("should get thread ids");
        assert!(!ids.is_empty());

        // The calling thread must appear in its own process's thread list.
        let current = ThreadInfo::current().expect("should get current thread");
        assert!(ids.contains(&current.thread_id));
    }

    #[test]
    fn test_thread_count() {
        let count = thread_count().expect("should get thread count");
        assert!(count >= 1);
    }
}