1use crate::error::{Result, WraithError};
7use crate::navigation::ModuleQuery;
8use crate::structures::Peb;
9use std::collections::HashMap;
10use std::sync::OnceLock;
11
12static GADGET_CACHE: OnceLock<Result<GadgetCache>> = OnceLock::new();
14
15pub fn init_global_cache() -> Result<()> {
17 let result = GADGET_CACHE.get_or_init(GadgetCache::build);
18 match result {
19 Ok(_) => Ok(()),
20 Err(e) => Err(WraithError::SyscallEnumerationFailed {
21 reason: format!("failed to build gadget cache: {}", e),
22 }),
23 }
24}
25
26pub fn get_global_cache() -> Result<&'static GadgetCache> {
28 let result = GADGET_CACHE.get_or_init(GadgetCache::build);
29 match result {
30 Ok(cache) => Ok(cache),
31 Err(e) => Err(WraithError::SyscallEnumerationFailed {
32 reason: format!("failed to get gadget cache: {}", e),
33 }),
34 }
35}
36
/// Kinds of instruction sequences ("gadgets") this module knows how to locate.
///
/// Variants with a fixed encoding expose it through [`GadgetType::bytes`];
/// [`GadgetType::name`] gives an assembly-style label for every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GadgetType {
    /// `jmp rbx`
    JmpRbx,
    /// `jmp rax`
    JmpRax,
    /// `jmp rcx`
    JmpRcx,
    /// `jmp rdx`
    JmpRdx,
    /// `jmp r8`
    JmpR8,
    /// `jmp r9`
    JmpR9,
    /// `jmp [rbx]`
    JmpIndirectRbx,
    /// `jmp [rax]`
    JmpIndirectRax,
    /// `call rbx`
    CallRbx,
    /// `call rax`
    CallRax,
    /// `ret`
    Ret,
    /// `add rsp, N; ret`.
    ///
    /// The finder in this file stores the 8-bit immediate in `offset` and uses
    /// `0` for the 32-bit-immediate form, so `offset` alone does not identify
    /// the encoding.
    AddRspRet { offset: u8 },
    /// `pop <reg>; ret`, where `register` is the x86-64 register number
    /// (0 = rax .. 7 = rdi, 8 = r8 .. 15 = r15).
    PopRet { register: u8 },
    /// `push rbx; ret`
    PushRbxRet,
}

impl GadgetType {
    /// Returns the machine-code bytes of this gadget.
    ///
    /// An empty slice means the variant has no single fixed encoding here:
    /// `AddRspRet` (two encodings share one variant) and `PopRet` with a
    /// register number above 15.
    #[cfg(target_arch = "x86_64")]
    pub fn bytes(&self) -> &'static [u8] {
        match self {
            Self::JmpRbx => &[0xFF, 0xE3],
            Self::JmpRax => &[0xFF, 0xE0],
            Self::JmpRcx => &[0xFF, 0xE1],
            Self::JmpRdx => &[0xFF, 0xE2],
            Self::JmpR8 => &[0x41, 0xFF, 0xE0],
            Self::JmpR9 => &[0x41, 0xFF, 0xE1],
            Self::JmpIndirectRbx => &[0xFF, 0x23],
            Self::JmpIndirectRax => &[0xFF, 0x20],
            Self::CallRbx => &[0xFF, 0xD3],
            Self::CallRax => &[0xFF, 0xD0],
            Self::Ret => &[0xC3],
            Self::AddRspRet { .. } => &[],
            Self::PopRet { register } => Self::pop_ret_bytes(*register),
            Self::PushRbxRet => &[0x53, 0xC3],
        }
    }

    /// Encoding of `pop <register>; ret`, or an empty slice for register
    /// numbers above 15.
    ///
    /// `pop r64` is opcode `0x58 + (register & 7)`; registers 8..=15 need the
    /// REX.B prefix `0x41`. `ret` is `0xC3`.
    #[cfg(target_arch = "x86_64")]
    fn pop_ret_bytes(register: u8) -> &'static [u8] {
        static LEGACY: [[u8; 2]; 8] = [
            [0x58, 0xC3],
            [0x59, 0xC3],
            [0x5A, 0xC3],
            [0x5B, 0xC3],
            [0x5C, 0xC3],
            [0x5D, 0xC3],
            [0x5E, 0xC3],
            [0x5F, 0xC3],
        ];
        static REX_B: [[u8; 3]; 8] = [
            [0x41, 0x58, 0xC3],
            [0x41, 0x59, 0xC3],
            [0x41, 0x5A, 0xC3],
            [0x41, 0x5B, 0xC3],
            [0x41, 0x5C, 0xC3],
            [0x41, 0x5D, 0xC3],
            [0x41, 0x5E, 0xC3],
            [0x41, 0x5F, 0xC3],
        ];

        let index = usize::from(register);
        if let Some(encoding) = LEGACY.get(index) {
            return encoding;
        }
        // `index >= 8` here because the lookup above covers 0..=7.
        if let Some(encoding) = REX_B.get(index - 8) {
            return encoding;
        }
        &[]
    }

    /// Returns an assembly-style label for the gadget.
    ///
    /// Parameterised variants use a generic label that does not include the
    /// concrete immediate or register.
    pub fn name(&self) -> &'static str {
        match self {
            Self::JmpRbx => "jmp rbx",
            Self::JmpRax => "jmp rax",
            Self::JmpRcx => "jmp rcx",
            Self::JmpRdx => "jmp rdx",
            Self::JmpR8 => "jmp r8",
            Self::JmpR9 => "jmp r9",
            Self::JmpIndirectRbx => "jmp [rbx]",
            Self::JmpIndirectRax => "jmp [rax]",
            Self::CallRbx => "call rbx",
            Self::CallRax => "call rax",
            Self::Ret => "ret",
            Self::AddRspRet { .. } => "add rsp, N; ret",
            Self::PopRet { .. } => "pop reg; ret",
            Self::PushRbxRet => "push rbx; ret",
        }
    }
}
112
113#[derive(Debug, Clone)]
115pub struct Gadget {
116 pub address: usize,
118 pub gadget_type: GadgetType,
120 pub module_name: String,
122 pub module_offset: usize,
124 pub is_system_module: bool,
126}
127
128impl Gadget {
129 pub fn is_valid(&self) -> bool {
131 let bytes = self.gadget_type.bytes();
132 if bytes.is_empty() {
133 return true; }
135
136 let actual = unsafe { std::slice::from_raw_parts(self.address as *const u8, bytes.len()) };
138 actual == bytes
139 }
140}
141
142#[derive(Debug, Clone)]
144pub struct JmpGadget {
145 pub gadget: Gadget,
146}
147
148impl JmpGadget {
149 pub fn address(&self) -> usize {
150 self.gadget.address
151 }
152}
153
154#[derive(Debug, Clone)]
156pub struct RetGadget {
157 pub gadget: Gadget,
158 pub stack_adjustment: usize,
160}
161
162impl RetGadget {
163 pub fn address(&self) -> usize {
164 self.gadget.address
165 }
166}
167
168#[derive(Debug)]
170pub struct GadgetCache {
171 by_type: HashMap<GadgetType, Vec<Gadget>>,
173 by_module: HashMap<String, Vec<Gadget>>,
175 preferred_jmp_rbx: Option<Gadget>,
177 preferred_jmp_rax: Option<Gadget>,
179 preferred_ret: Option<Gadget>,
181}
182
183impl GadgetCache {
184 pub fn build() -> Result<Self> {
186 let finder = GadgetFinder::new()?;
187
188 let mut by_type: HashMap<GadgetType, Vec<Gadget>> = HashMap::new();
189 let mut by_module: HashMap<String, Vec<Gadget>> = HashMap::new();
190
191 let modules = ["ntdll.dll", "kernel32.dll", "kernelbase.dll"];
193
194 for module_name in modules {
195 if let Ok(gadgets) = finder.scan_module_all(module_name) {
196 for gadget in gadgets {
197 let module_lower = gadget.module_name.to_lowercase();
198
199 by_type
200 .entry(gadget.gadget_type)
201 .or_default()
202 .push(gadget.clone());
203
204 by_module.entry(module_lower).or_default().push(gadget);
205 }
206 }
207 }
208
209 let preferred_jmp_rbx = by_type
211 .get(&GadgetType::JmpRbx)
212 .and_then(|v| v.iter().find(|g| g.module_name.eq_ignore_ascii_case("ntdll.dll")))
213 .cloned();
214
215 let preferred_jmp_rax = by_type
216 .get(&GadgetType::JmpRax)
217 .and_then(|v| v.iter().find(|g| g.module_name.eq_ignore_ascii_case("ntdll.dll")))
218 .cloned();
219
220 let preferred_ret = by_type
221 .get(&GadgetType::Ret)
222 .and_then(|v| {
223 v.iter()
224 .find(|g| g.module_name.eq_ignore_ascii_case("kernel32.dll"))
225 })
226 .cloned();
227
228 Ok(Self {
229 by_type,
230 by_module,
231 preferred_jmp_rbx,
232 preferred_jmp_rax,
233 preferred_ret,
234 })
235 }
236
237 pub fn jmp_rbx(&self) -> Option<&Gadget> {
239 self.preferred_jmp_rbx.as_ref()
240 }
241
242 pub fn jmp_rax(&self) -> Option<&Gadget> {
244 self.preferred_jmp_rax.as_ref()
245 }
246
247 pub fn ret_gadget(&self) -> Option<&Gadget> {
249 self.preferred_ret.as_ref()
250 }
251
252 pub fn get_by_type(&self, gadget_type: GadgetType) -> &[Gadget] {
254 self.by_type.get(&gadget_type).map(|v| v.as_slice()).unwrap_or(&[])
255 }
256
257 pub fn get_by_module(&self, module_name: &str) -> &[Gadget] {
259 self.by_module
260 .get(&module_name.to_lowercase())
261 .map(|v| v.as_slice())
262 .unwrap_or(&[])
263 }
264
265 pub fn any_jmp_gadget(&self) -> Option<&Gadget> {
267 self.preferred_jmp_rbx
268 .as_ref()
269 .or(self.preferred_jmp_rax.as_ref())
270 .or_else(|| {
271 self.by_type
272 .get(&GadgetType::JmpRbx)
273 .and_then(|v| v.first())
274 })
275 .or_else(|| {
276 self.by_type
277 .get(&GadgetType::JmpRax)
278 .and_then(|v| v.first())
279 })
280 }
281}
282
283pub struct GadgetFinder {
285 peb: Peb,
286}
287
288impl GadgetFinder {
289 pub fn new() -> Result<Self> {
291 Ok(Self {
292 peb: Peb::current()?,
293 })
294 }
295
296 pub fn find_jmp_rbx(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
298 self.find_gadgets_of_type(module_name, GadgetType::JmpRbx)
299 .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
300 }
301
302 pub fn find_jmp_rax(&self, module_name: &str) -> Result<Vec<JmpGadget>> {
304 self.find_gadgets_of_type(module_name, GadgetType::JmpRax)
305 .map(|gadgets| gadgets.into_iter().map(|g| JmpGadget { gadget: g }).collect())
306 }
307
308 pub fn find_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
310 self.find_gadgets_of_type(module_name, GadgetType::Ret)
311 .map(|gadgets| {
312 gadgets
313 .into_iter()
314 .map(|g| RetGadget {
315 gadget: g,
316 stack_adjustment: 0,
317 })
318 .collect()
319 })
320 }
321
322 pub fn find_gadgets_of_type(
324 &self,
325 module_name: &str,
326 gadget_type: GadgetType,
327 ) -> Result<Vec<Gadget>> {
328 let query = ModuleQuery::new(&self.peb);
329 let module = query.find_by_name(module_name)?;
330
331 let bytes = gadget_type.bytes();
332 if bytes.is_empty() {
333 return Ok(Vec::new());
334 }
335
336 let base = module.base();
337 let size = module.size();
338 let name = module.name();
339 let is_system = is_system_module(&name);
340
341 let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
344
345 let mut gadgets = Vec::new();
346 let pattern_len = bytes.len();
347
348 for offset in 0..=(size.saturating_sub(pattern_len)) {
350 if &data[offset..offset + pattern_len] == bytes {
351 gadgets.push(Gadget {
352 address: base + offset,
353 gadget_type,
354 module_name: name.clone(),
355 module_offset: offset,
356 is_system_module: is_system,
357 });
358 }
359 }
360
361 Ok(gadgets)
362 }
363
364 pub fn find_add_rsp_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
366 let query = ModuleQuery::new(&self.peb);
367 let module = query.find_by_name(module_name)?;
368
369 let base = module.base();
370 let size = module.size();
371 let name = module.name();
372 let is_system = is_system_module(&name);
373
374 let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
376
377 let mut gadgets = Vec::new();
378
379 for offset in 0..=(size.saturating_sub(5)) {
382 if data[offset] == 0x48
383 && data[offset + 1] == 0x83
384 && data[offset + 2] == 0xC4
385 && data[offset + 4] == 0xC3
386 {
387 let stack_adj = data[offset + 3] as usize;
388 gadgets.push(RetGadget {
389 gadget: Gadget {
390 address: base + offset,
391 gadget_type: GadgetType::AddRspRet {
392 offset: data[offset + 3],
393 },
394 module_name: name.clone(),
395 module_offset: offset,
396 is_system_module: is_system,
397 },
398 stack_adjustment: stack_adj,
399 });
400 }
401 }
402
403 for offset in 0..=(size.saturating_sub(8)) {
406 if data[offset] == 0x48
407 && data[offset + 1] == 0x81
408 && data[offset + 2] == 0xC4
409 && data[offset + 7] == 0xC3
410 {
411 let stack_adj = u32::from_le_bytes([
412 data[offset + 3],
413 data[offset + 4],
414 data[offset + 5],
415 data[offset + 6],
416 ]) as usize;
417
418 gadgets.push(RetGadget {
419 gadget: Gadget {
420 address: base + offset,
421 gadget_type: GadgetType::AddRspRet {
422 offset: 0, },
424 module_name: name.clone(),
425 module_offset: offset,
426 is_system_module: is_system,
427 },
428 stack_adjustment: stack_adj,
429 });
430 }
431 }
432
433 Ok(gadgets)
434 }
435
436 pub fn find_pop_ret(&self, module_name: &str) -> Result<Vec<RetGadget>> {
438 let query = ModuleQuery::new(&self.peb);
439 let module = query.find_by_name(module_name)?;
440
441 let base = module.base();
442 let size = module.size();
443 let name = module.name();
444 let is_system = is_system_module(&name);
445
446 let data = unsafe { std::slice::from_raw_parts(base as *const u8, size) };
448
449 let mut gadgets = Vec::new();
450
451 for offset in 0..=(size.saturating_sub(2)) {
460 let first = data[offset];
461 if (0x58..=0x5F).contains(&first) && first != 0x5C && data[offset + 1] == 0xC3 {
462 gadgets.push(RetGadget {
463 gadget: Gadget {
464 address: base + offset,
465 gadget_type: GadgetType::PopRet {
466 register: first - 0x58,
467 },
468 module_name: name.clone(),
469 module_offset: offset,
470 is_system_module: is_system,
471 },
472 stack_adjustment: 8, });
474 }
475 }
476
477 for offset in 0..=(size.saturating_sub(3)) {
482 if data[offset] == 0x41
483 && (0x58..=0x5F).contains(&data[offset + 1])
484 && data[offset + 1] != 0x5C
485 && data[offset + 2] == 0xC3
486 {
487 gadgets.push(RetGadget {
488 gadget: Gadget {
489 address: base + offset,
490 gadget_type: GadgetType::PopRet {
491 register: data[offset + 1] - 0x58 + 8,
492 },
493 module_name: name.clone(),
494 module_offset: offset,
495 is_system_module: is_system,
496 },
497 stack_adjustment: 8,
498 });
499 }
500 }
501
502 Ok(gadgets)
503 }
504
505 pub fn scan_module_all(&self, module_name: &str) -> Result<Vec<Gadget>> {
507 let mut all_gadgets = Vec::new();
508
509 for gadget_type in [
511 GadgetType::JmpRbx,
512 GadgetType::JmpRax,
513 GadgetType::JmpRcx,
514 GadgetType::JmpRdx,
515 GadgetType::CallRbx,
516 GadgetType::CallRax,
517 GadgetType::Ret,
518 GadgetType::PushRbxRet,
519 ] {
520 if let Ok(gadgets) = self.find_gadgets_of_type(module_name, gadget_type) {
521 all_gadgets.extend(gadgets);
522 }
523 }
524
525 if let Ok(ret_gadgets) = self.find_add_rsp_ret(module_name) {
527 all_gadgets.extend(ret_gadgets.into_iter().map(|r| r.gadget));
528 }
529
530 if let Ok(pop_gadgets) = self.find_pop_ret(module_name) {
532 all_gadgets.extend(pop_gadgets.into_iter().map(|r| r.gadget));
533 }
534
535 Ok(all_gadgets)
536 }
537
538 pub fn find_best_jmp_gadget(&self) -> Result<JmpGadget> {
541 if let Ok(gadgets) = self.find_jmp_rbx("ntdll.dll") {
543 if let Some(g) = gadgets.into_iter().next() {
544 return Ok(g);
545 }
546 }
547
548 if let Ok(gadgets) = self.find_jmp_rax("ntdll.dll") {
549 if let Some(g) = gadgets.into_iter().next() {
550 return Ok(g);
551 }
552 }
553
554 if let Ok(gadgets) = self.find_jmp_rbx("kernelbase.dll") {
556 if let Some(g) = gadgets.into_iter().next() {
557 return Ok(g);
558 }
559 }
560
561 if let Ok(gadgets) = self.find_jmp_rbx("kernel32.dll") {
563 if let Some(g) = gadgets.into_iter().next() {
564 return Ok(g);
565 }
566 }
567
568 Err(WraithError::SyscallEnumerationFailed {
569 reason: "no suitable jmp gadget found".into(),
570 })
571 }
572
573 pub fn find_best_ret_gadget(&self) -> Result<RetGadget> {
575 if let Ok(gadgets) = self.find_ret("kernel32.dll") {
577 if let Some(g) = gadgets.into_iter().next() {
578 return Ok(g);
579 }
580 }
581
582 if let Ok(gadgets) = self.find_ret("kernelbase.dll") {
583 if let Some(g) = gadgets.into_iter().next() {
584 return Ok(g);
585 }
586 }
587
588 if let Ok(gadgets) = self.find_ret("ntdll.dll") {
589 if let Some(g) = gadgets.into_iter().next() {
590 return Ok(g);
591 }
592 }
593
594 Err(WraithError::SyscallEnumerationFailed {
595 reason: "no suitable ret gadget found".into(),
596 })
597 }
598}
599
/// Returns `true` when `name` is one of the core Windows DLLs listed below.
///
/// The comparison ignores ASCII case only. Unicode-aware lowercasing would
/// also accept look-alike names: KELVIN SIGN (U+212A) lowercases to `k`, so
/// such a name would compare equal to `kernel32.dll`.
fn is_system_module(name: &str) -> bool {
    const SYSTEM_MODULES: [&str; 10] = [
        "ntdll.dll",
        "kernel32.dll",
        "kernelbase.dll",
        "user32.dll",
        "gdi32.dll",
        "advapi32.dll",
        "msvcrt.dll",
        "ws2_32.dll",
        "ole32.dll",
        "combase.dll",
    ];

    SYSTEM_MODULES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
614
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared setup: a finder for the current process.
    fn new_finder() -> GadgetFinder {
        GadgetFinder::new().expect("should create finder")
    }

    /// ntdll should contain at least one `jmp rbx`, and the first hit should
    /// re-validate and be flagged as a system module.
    #[test]
    fn test_find_jmp_rbx_ntdll() {
        let hits = new_finder()
            .find_jmp_rbx("ntdll.dll")
            .expect("should find gadgets");

        assert!(!hits.is_empty(), "should find jmp rbx gadgets in ntdll");

        let head = &hits[0];
        assert!(head.gadget.is_valid(), "gadget should be valid");
        assert!(head.gadget.is_system_module, "should be system module");
    }

    /// kernel32 should contain at least one `ret`, and the first hit should
    /// re-validate.
    #[test]
    fn test_find_ret_gadgets() {
        let hits = new_finder()
            .find_ret("kernel32.dll")
            .expect("should find gadgets");

        assert!(!hits.is_empty(), "should find ret gadgets in kernel32");

        let head = &hits[0];
        assert!(head.gadget.is_valid(), "gadget should be valid");
    }

    /// When the scan succeeds, the first few `add rsp, N; ret` hits must
    /// report a non-zero adjustment. A failed lookup is tolerated.
    #[test]
    fn test_find_add_rsp_ret() {
        let scan = new_finder().find_add_rsp_ret("ntdll.dll");

        let Ok(hits) = scan else {
            return;
        };
        for hit in hits.iter().take(5) {
            assert!(hit.stack_adjustment > 0, "should have stack adjustment");
        }
    }

    /// A freshly built cache should expose at least one preferred jump gadget.
    #[test]
    fn test_gadget_cache() {
        let cache = GadgetCache::build().expect("should build cache");

        assert!(cache.jmp_rbx().is_some() || cache.jmp_rax().is_some());
    }

    /// The best jump gadget should exist and re-validate.
    #[test]
    fn test_best_jmp_gadget() {
        let best = new_finder()
            .find_best_jmp_gadget()
            .expect("should find gadget");

        assert!(best.gadget.is_valid(), "best gadget should be valid");
    }
}