wraith/manipulation/syscall/
enumerator.rs

//! Syscall Service Number (SSN) enumeration from ntdll
//!
//! Extracts SSNs by parsing the ntdll export directory and reading each
//! syscall stub's prologue to find its `mov eax, imm32` instruction.
5
6use crate::error::{Result, WraithError};
7use crate::navigation::{Module, ModuleQuery};
8use crate::structures::pe::{DataDirectoryType, ExportDirectory};
9use crate::structures::Peb;
10use crate::util::hash::djb2_hash;
11
/// Opcode byte sequences used to recognise ntdll syscall stubs.
mod patterns {
    /// `mov eax, imm32` opcode; the immediate that follows carries the SSN.
    /// Used by both the x64 and the x86 stub parsers.
    pub const MOV_EAX: u8 = b'\xB8';

    // An unhooked x64 stub is laid out as:
    //   4C 8B D1          mov r10, rcx
    //   B8 XX XX 00 00    mov eax, <ssn>
    //   ...
    //   0F 05             syscall
    //   C3                ret

    /// `mov r10, rcx`: first instruction of an unhooked x64 stub.
    #[cfg(target_arch = "x86_64")]
    pub const MOV_R10_RCX: [u8; 3] = *b"\x4C\x8B\xD1";

    /// `syscall` instruction.
    #[cfg(target_arch = "x86_64")]
    pub const SYSCALL: [u8; 2] = *b"\x0F\x05";

    /// `int 0x2e` software interrupt.
    #[cfg(target_arch = "x86")]
    pub const INT_2E: [u8; 2] = *b"\xCD\x2E";

    /// `sysenter` instruction.
    #[cfg(target_arch = "x86")]
    pub const SYSENTER: [u8; 2] = *b"\x0F\x34";
}
35
36/// enumerates syscalls from ntdll exports
37pub struct SyscallEnumerator<'a> {
38    ntdll: Module<'a>,
39}
40
41impl<'a> SyscallEnumerator<'a> {
42    /// create enumerator for ntdll
43    pub fn new(ntdll: Module<'a>) -> Self {
44        Self { ntdll }
45    }
46
47    /// enumerate all syscalls and their SSNs
48    pub fn enumerate(&self) -> Result<Vec<EnumeratedSyscall>> {
49        let mut syscalls = Vec::new();
50
51        let nt = self.ntdll.nt_headers()?;
52        let export_dir = nt
53            .data_directory(DataDirectoryType::Export.index())
54            .ok_or(WraithError::SyscallEnumerationFailed {
55                reason: "no export directory".into(),
56            })?;
57
58        if !export_dir.is_present() {
59            return Err(WraithError::SyscallEnumerationFailed {
60                reason: "export directory not present".into(),
61            });
62        }
63
64        let base = self.ntdll.base();
65        // SAFETY: export directory RVA points to valid memory in loaded ntdll
66        let exports = unsafe {
67            &*((base + export_dir.virtual_address as usize) as *const ExportDirectory)
68        };
69
70        let num_names = exports.number_of_names as usize;
71        let names = base + exports.address_of_names as usize;
72        let ordinals = base + exports.address_of_name_ordinals as usize;
73        let functions = base + exports.address_of_functions as usize;
74
75        for i in 0..num_names {
76            // SAFETY: iterating within bounds of export arrays
77            let name_rva = unsafe { *((names + i * 4) as *const u32) };
78            let name_ptr = (base + name_rva as usize) as *const u8;
79
80            // read function name with bounds checking
81            let name = unsafe {
82                let mut len = 0;
83                while *name_ptr.add(len) != 0 && len < 256 {
84                    len += 1;
85                }
86                let bytes = core::slice::from_raw_parts(name_ptr, len);
87                match core::str::from_utf8(bytes) {
88                    Ok(s) => s,
89                    Err(_) => continue, // skip invalid UTF-8
90                }
91            };
92
93            // only process Nt/Zw functions (syscalls)
94            if !name.starts_with("Nt") && !name.starts_with("Zw") {
95                continue;
96            }
97
98            // skip Nt functions that aren't syscalls (they're just accessors)
99            if matches!(
100                name,
101                "NtCurrentTeb"
102                    | "NtCurrentPeb"
103                    | "NtGetTickCount"
104                    | "NtdllDefWindowProc_A"
105                    | "NtdllDefWindowProc_W"
106                    | "NtdllDialogWndProc_A"
107                    | "NtdllDialogWndProc_W"
108            ) {
109                continue;
110            }
111
112            let ordinal = unsafe { *((ordinals + i * 2) as *const u16) };
113            let func_rva = unsafe { *((functions + ordinal as usize * 4) as *const u32) };
114            let func_addr = base + func_rva as usize;
115
116            // check for forwarded export
117            if func_rva >= export_dir.virtual_address
118                && func_rva < export_dir.virtual_address + export_dir.size
119            {
120                continue;
121            }
122
123            // try to extract SSN from the stub
124            if let Some(ssn) = self.extract_ssn(func_addr) {
125                syscalls.push(EnumeratedSyscall {
126                    name: name.to_string(),
127                    name_hash: djb2_hash(name.as_bytes()),
128                    ssn,
129                    address: func_addr,
130                    syscall_address: self.find_syscall_instruction(func_addr),
131                });
132            }
133        }
134
135        // sort by SSN (they should be sequential)
136        syscalls.sort_by_key(|s| s.ssn);
137
138        Ok(syscalls)
139    }
140
141    /// extract SSN from syscall stub (x64)
142    #[cfg(target_arch = "x86_64")]
143    fn extract_ssn(&self, addr: usize) -> Option<u16> {
144        // SAFETY: reading from function address in loaded ntdll
145        let bytes = unsafe { core::slice::from_raw_parts(addr as *const u8, 32) };
146
147        // standard pattern: 4C 8B D1 B8 XX XX 00 00
148        if bytes.len() >= 8
149            && bytes[0..3] == patterns::MOV_R10_RCX
150            && bytes[3] == patterns::MOV_EAX
151        {
152            let ssn = u16::from_le_bytes([bytes[4], bytes[5]]);
153            return Some(ssn);
154        }
155
156        // hooked stub might have different prologue - scan for mov eax pattern
157        for i in 0..20 {
158            if i + 2 < bytes.len() && bytes[i] == patterns::MOV_EAX {
159                let ssn = u16::from_le_bytes([bytes[i + 1], bytes[i + 2]]);
160                if ssn < 0x1000 {
161                    return Some(ssn);
162                }
163            }
164        }
165
166        None
167    }
168
169    /// extract SSN from syscall stub (x86)
170    #[cfg(target_arch = "x86")]
171    fn extract_ssn(&self, addr: usize) -> Option<u16> {
172        // SAFETY: reading from function address in loaded ntdll
173        let bytes = unsafe { core::slice::from_raw_parts(addr as *const u8, 32) };
174
175        // pattern: B8 XX XX 00 00
176        if bytes.len() >= 5 && bytes[0] == patterns::MOV_EAX {
177            let ssn = u16::from_le_bytes([bytes[1], bytes[2]]);
178            return Some(ssn);
179        }
180
181        None
182    }
183
184    /// find syscall/sysenter instruction address in stub (x64)
185    #[cfg(target_arch = "x86_64")]
186    fn find_syscall_instruction(&self, func_addr: usize) -> Option<usize> {
187        // SAFETY: reading from function in loaded ntdll
188        let bytes = unsafe { core::slice::from_raw_parts(func_addr as *const u8, 32) };
189
190        // look for syscall (0F 05)
191        for i in 0..30 {
192            if i + 1 < bytes.len() && bytes[i..].starts_with(&patterns::SYSCALL) {
193                return Some(func_addr + i);
194            }
195        }
196
197        None
198    }
199
200    /// find syscall/sysenter instruction address in stub (x86)
201    #[cfg(target_arch = "x86")]
202    fn find_syscall_instruction(&self, func_addr: usize) -> Option<usize> {
203        // SAFETY: reading from function in loaded ntdll
204        let bytes = unsafe { core::slice::from_raw_parts(func_addr as *const u8, 64) };
205
206        // look for int 0x2e or sysenter
207        for i in 0..60 {
208            if i + 1 < bytes.len()
209                && (bytes[i..].starts_with(&patterns::INT_2E)
210                    || bytes[i..].starts_with(&patterns::SYSENTER))
211            {
212                return Some(func_addr + i);
213            }
214        }
215
216        None
217    }
218
219    /// resolve SSN using "Halo's Gate" technique
220    ///
221    /// if a syscall is hooked, look at neighboring syscalls
222    /// (SSNs are sequential, so Nt* functions nearby have SSN +/- N)
223    #[allow(dead_code)]
224    pub fn resolve_hooked_ssn(&self, target_addr: usize) -> Option<u16> {
225        // search upward (earlier functions have lower SSNs)
226        for offset in 1..=20u16 {
227            // typical syscall stub size is ~32 bytes
228            let check_addr = target_addr.wrapping_sub(offset as usize * 32);
229            if let Some(ssn) = self.extract_ssn(check_addr) {
230                return Some(ssn.wrapping_add(offset));
231            }
232        }
233
234        // search downward (later functions have higher SSNs)
235        for offset in 1..=20u16 {
236            let check_addr = target_addr + (offset as usize * 32);
237            if let Some(ssn) = self.extract_ssn(check_addr) {
238                return ssn.checked_sub(offset);
239            }
240        }
241
242        None
243    }
244}
245
/// enumerated syscall information
///
/// Equality and hashing compare every field, so entries can be deduplicated
/// or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EnumeratedSyscall {
    /// function name (e.g., "NtOpenProcess")
    pub name: String,
    /// hash of function name for fast lookup
    pub name_hash: u32,
    /// syscall service number
    pub ssn: u16,
    /// address in ntdll
    pub address: usize,
    /// address of syscall instruction (for indirect calls), if one was found
    pub syscall_address: Option<usize>,
}
260
261/// enumerate syscalls from current process's ntdll
262pub fn enumerate_syscalls() -> Result<Vec<EnumeratedSyscall>> {
263    let peb = Peb::current()?;
264    let query = ModuleQuery::new(&peb);
265    let ntdll = query.ntdll().map_err(|_| WraithError::NtdllNotFound)?;
266
267    let enumerator = SyscallEnumerator::new(ntdll);
268    enumerator.enumerate()
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// NtClose is expected to be present in ntdll, so it should be
    /// discovered with a plausible SSN.
    #[test]
    fn test_enumerate_syscalls() {
        let syscalls = enumerate_syscalls().expect("should enumerate syscalls");
        assert!(!syscalls.is_empty(), "should find at least some syscalls");

        let close = syscalls
            .iter()
            .find(|entry| entry.name == "NtClose")
            .expect("should find NtClose");

        // sanity bound on the SSN value (< 0x500 on most Windows versions)
        assert!(close.ssn < 0x500, "NtClose SSN should be reasonable");
    }

    /// `enumerate` sorts its output by SSN, so neighbours must be non-decreasing.
    #[test]
    fn test_ssn_ordering() {
        let syscalls = enumerate_syscalls().expect("should enumerate syscalls");

        for pair in syscalls.windows(2) {
            assert!(pair[0].ssn <= pair[1].ssn, "SSNs should be sorted");
        }
    }
}