1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
#![cfg_attr(not(feature = "std"), no_std)]

use core::{
    cell::UnsafeCell,
    marker::PhantomData,
    mem::MaybeUninit,
    ops::{Deref, DerefMut},
    sync::atomic::AtomicBool,
};

use hooker::gen_hook_info;
use thiserror_no_std::Error;
use windows_sys::{
    core::PCSTR,
    Win32::{
        Foundation::{FreeLibrary, GetLastError, HMODULE},
        System::{
            Diagnostics::Debug::WriteProcessMemory,
            LibraryLoader::{GetProcAddress, LoadLibraryA},
            Memory::{
                VirtualAlloc, VirtualFree, VirtualProtect, MEM_COMMIT, MEM_RELEASE, PAGE_EXECUTE,
                PAGE_PROTECTION_FLAGS, PAGE_READWRITE,
            },
            ProcessStatus::{GetModuleInformation, MODULEINFO},
            Threading::GetCurrentProcess,
        },
    },
};

/// a minimal try-lock: acquiring it while it is already held fails immediately instead of blocking.
struct SingleLock<T> {
    value: UnsafeCell<T>,
    is_locked: AtomicBool,
}
impl<T> SingleLock<T> {
    /// wraps `value` in a new lock which starts out unlocked.
    const fn new(value: T) -> Self {
        Self {
            is_locked: AtomicBool::new(false),
            value: UnsafeCell::new(value),
        }
    }
    /// tries to acquire the lock.
    ///
    /// returns a guard which grants access to the value, or `None` if another guard is still alive.
    fn lock(&self) -> Option<SingleLockGuard<'_, T>> {
        // the swap always sets the flag and reports whether it was set before. if it was,
        // somebody else holds the lock and no second guard may be handed out.
        let was_locked = self
            .is_locked
            .swap(true, core::sync::atomic::Ordering::AcqRel);
        if was_locked {
            None
        } else {
            Some(SingleLockGuard { lock: self })
        }
    }
}
// NOTE(review): both impls are unconditional (there is no `T: Send` bound), so the lock is treated
// as thread-safe for every `T`. confirm that this is intended for all values stored in it.
unsafe impl<T> Send for SingleLock<T> {}
unsafe impl<T> Sync for SingleLock<T> {}

/// the guard returned by `SingleLock::lock`. dropping it releases the lock.
struct SingleLockGuard<'a, T> {
    lock: &'a SingleLock<T>,
}
impl<'a, T> Drop for SingleLockGuard<'a, T> {
    fn drop(&mut self) {
        // clear the flag so that the next `lock` call can succeed.
        let flag = &self.lock.is_locked;
        flag.store(false, core::sync::atomic::Ordering::Release);
    }
}
impl<'a, T> Deref for SingleLockGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let value_ptr: *const T = self.lock.value.get();
        // a guard only exists while the lock flag is set, which is what guards access to the cell.
        unsafe { &*value_ptr }
    }
}
impl<'a, T> DerefMut for SingleLockGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr: *mut T = self.lock.value.get();
        // a guard only exists while the lock flag is set, which is what guards access to the cell.
        unsafe { &mut *value_ptr }
    }
}

/// a hook slot which is meant to be stored in a `static`, so that the hook can be reached from anywhere.
/// the generic argument `F` should be the function pointer type of the hooked function (e.g `extern "C" fn(i32) -> i32`).
pub struct StaticHook<F: Copy> {
    hook: SingleLock<Option<Hook>>,
    phantom: PhantomData<F>,
}
impl<F: Copy> StaticHook<F> {
    const HOOK_USED_MULTIPLE_TIMES_ERR_MSG: &'static str = "static hook used multiple times";

    /// creates a static hook which does not hold any hook yet.
    pub const fn new() -> Self {
        Self {
            phantom: PhantomData,
            hook: SingleLock::new(None),
        }
    }

    /// acquires the internal lock and returns its guard.
    ///
    /// # Panics
    ///
    /// panics if the lock is currently held by someone else.
    fn lock_hook(&self) -> SingleLockGuard<'_, Option<Hook>> {
        let maybe_guard = self.hook.lock();
        maybe_guard.expect(Self::HOOK_USED_MULTIPLE_TIMES_ERR_MSG)
    }

    /// acquires the internal lock, checks that no hook was stored yet, and returns the guard.
    ///
    /// # Panics
    ///
    /// panics if the lock is currently held by someone else, or if a hook is already stored.
    fn lock_hook_and_assert_empty(&self) -> SingleLockGuard<'_, Option<Hook>> {
        let guard = self.lock_hook();
        assert!(guard.is_none(), "{}", Self::HOOK_USED_MULTIPLE_TIMES_ERR_MSG);
        guard
    }

    /// hooks the function at `fn_addr` inside `module`, so that calling it jumps to `hook_to_addr`,
    /// and stores the resulting hook in this static hook.
    ///
    /// the lock is held for the whole operation. if hooking fails, the error is returned and the
    /// static hook stays empty.
    ///
    /// # Safety
    ///
    /// the signature of the hooked function must match the signature `F` of this static hook.
    ///
    /// # Panics
    ///
    /// panics if this static hook was already used to hook some function.
    pub unsafe fn hook_function(
        &self,
        module: HMODULE,
        fn_addr: usize,
        hook_to_addr: usize,
    ) -> Result<()> {
        let mut slot = self.lock_hook_and_assert_empty();
        hook_function(module, fn_addr, hook_to_addr).map(|created| *slot = Some(created))
    }

    /// hooks the function named `fn_name` from the library named `library_name`, so that calling it
    /// jumps to `hook_to_addr`, and stores the resulting hook in this static hook.
    ///
    /// the signature of the hooked function must match the signature `F` of this static hook, since
    /// [`StaticHook::original`] reinterprets the trampoline as an `F`. note that this function is
    /// not marked `unsafe`, so the compiler does not make the caller acknowledge that requirement.
    ///
    /// # Panics
    ///
    /// panics if this static hook was already used to hook some function.
    pub fn hook_function_by_name(
        &self,
        library_name: PCSTR,
        fn_name: PCSTR,
        hook_to_addr: usize,
    ) -> Result<()> {
        let mut slot = self.lock_hook_and_assert_empty();
        let created = hook_function_by_name(library_name, fn_name, hook_to_addr)?;
        *slot = Some(created);
        Ok(())
    }

    /// returns a reference to the stored hook.
    ///
    /// the internal lock is only held for the duration of this call, while the returned reference
    /// lives as long as `self`.
    ///
    /// # Panics
    ///
    /// panics if the static hook was not yet used to hook any function, or if the internal lock is
    /// held by someone else at the moment of the call.
    pub fn get_hook(&self) -> &Hook {
        let guard = self.lock_hook();
        // go through a raw pointer to detach the reference's lifetime from the guard, which is
        // dropped when this function returns.
        let slot_ptr: *const Option<Hook> = &*guard;
        let slot: &Option<Hook> = unsafe { &*slot_ptr };
        slot.as_ref()
            .expect("static hook used before hooking any function")
    }

    /// returns a function pointer of type `F` which, when called, behaves like the original function
    /// did before it was hooked.
    ///
    /// # Panics
    ///
    /// panics if the static hook was not yet used to hook any function.
    pub fn original(&self) -> F {
        unsafe { self.get_hook().original() }
    }
}

/// owns a module handle and passes it to `FreeLibrary` when dropped.
struct ModuleHandleGuard(HMODULE);
impl Drop for ModuleHandleGuard {
    fn drop(&mut self) {
        let handle = self.0;
        // the return value is deliberately ignored, nothing useful can be done about a failure here.
        unsafe {
            let _ = FreeLibrary(handle);
        }
    }
}

/// hooks the function with the `fn_name` from the library with the provided `library_name` such that when the function is called it instead jumps
/// to the given `hook_to_addr`.
pub fn hook_function_by_name(
    library_name: PCSTR,
    fn_name: PCSTR,
    hook_to_addr: usize,
) -> Result<Hook> {
    let load_library_res = unsafe { LoadLibraryA(library_name) };
    if load_library_res == 0 {
        return Err(Error::FailedToLoadLibrary(WinapiError::last_error()));
    }
    let module_guard = ModuleHandleGuard(load_library_res);
    let fn_addr =
        unsafe { GetProcAddress(module_guard.0, fn_name).ok_or(Error::NoFunctionWithThatName)? };
    hook_function(module_guard.0, fn_addr as usize, hook_to_addr)
}

/// hooks the function with the given `fn_addr` from the given `module` such that when the function is called it instead jumps
/// to the given `hook_to_addr`.
///
/// the bytes from `fn_addr` up to the end of the module's image are handed to `gen_hook_info`. a
/// trampoline is then allocated and filled in, and finally the jumper is written over the start
/// of the function.
///
/// # Errors
///
/// - [`Error::FailedToGetModuleInformation`] if `GetModuleInformation` fails for `module`.
/// - [`Error::FailedToGenHookInfo`] if `gen_hook_info` fails.
///
/// # Panics
///
/// - if `fn_addr` lies past the end of the image of `module`.
/// - if allocating the trampoline or changing its memory protection fails.
/// - if writing the jumper fails, or writes fewer bytes than the jumper's length.
pub fn hook_function(module: HMODULE, fn_addr: usize, hook_to_addr: usize) -> Result<Hook> {
    // find out where the module's image ends, this bounds how long the function can possibly be
    let mut module_info_uninit: MaybeUninit<MODULEINFO> = MaybeUninit::uninit();
    let res = unsafe {
        GetModuleInformation(
            GetCurrentProcess(),
            module,
            module_info_uninit.as_mut_ptr(),
            core::mem::size_of::<MODULEINFO>() as u32,
        )
    };
    if res == 0 {
        return Err(Error::FailedToGetModuleInformation(
            WinapiError::last_error(),
        ));
    }
    // the call reported success, so the structure was filled in
    let module_info = unsafe { module_info_uninit.assume_init() };
    let module_end_addr = module_info.lpBaseOfDll as usize + module_info.SizeOfImage as usize;
    // `fn_addr` is provided by the caller. a plain subtraction would wrap around in release builds
    // when the address is past the end of the module, producing a bogus, enormous slice length.
    let fn_max_possible_size = module_end_addr
        .checked_sub(fn_addr)
        .expect("function address is past the end of the provided module");
    let fn_possible_content =
        unsafe { core::slice::from_raw_parts(fn_addr as *const u8, fn_max_possible_size) };
    let hook_info = gen_hook_info(fn_possible_content, fn_addr as u64, hook_to_addr as u64)?;

    // allocate the trampoline and copy its code
    let mut trampoline_alloc = Allocation::new(hook_info.trampoline_size());
    let trampoline_code = hook_info.build_trampoline(trampoline_alloc.ptr as u64);
    // the allocation was just created as read-write memory, so writing to it is fine
    let trampoline_alloc_slice = unsafe { trampoline_alloc.as_mut_slice() };
    trampoline_alloc_slice[..trampoline_code.len()].copy_from_slice(&trampoline_code);

    // done writing the trampoline, now make it executable
    trampoline_alloc.make_executable_and_read_only();

    // write the jumper
    let jumper_code = hook_info.jumper();
    let mut bytes_written = 0;
    let res = unsafe {
        WriteProcessMemory(
            GetCurrentProcess(),
            fn_addr as *const _,
            jumper_code.as_ptr().cast(),
            jumper_code.len(),
            &mut bytes_written,
        )
    };
    assert_ne!(res, 0, "failed to write jumper");
    // make sure that all bytes were written
    assert_eq!(
        bytes_written,
        jumper_code.len(),
        "not all bytes of jumper were written to the start of the function"
    );

    Ok(Hook {
        trampoline: trampoline_alloc,
        fn_addr,
        hook_to_addr,
    })
}

/// a block of memory obtained from `VirtualAlloc`, released using `VirtualFree` when dropped.
pub struct Allocation {
    ptr: *mut u8,
    size: usize,
}
impl Allocation {
    /// allocates `size` bytes of committed, read-write memory.
    ///
    /// # Panics
    ///
    /// panics if `VirtualAlloc` returns a null pointer.
    fn new(size: usize) -> Self {
        let base = unsafe { VirtualAlloc(core::ptr::null(), size, MEM_COMMIT, PAGE_READWRITE) };
        // should never happen except for OOM, in which case the default behaviour is to panic anyways.
        assert!(
            !base.is_null(),
            "failed to allocate read-write memory using VirtualAlloc"
        );
        Self {
            size,
            ptr: base.cast(),
        }
    }
    /// changes the protection of the whole allocation to `PAGE_EXECUTE`.
    ///
    /// NOTE(review): `PAGE_EXECUTE` does not request read access, while the name of this method
    /// suggests that execute + read was intended. confirm against the winapi documentation.
    ///
    /// # Panics
    ///
    /// panics if `VirtualProtect` fails.
    fn make_executable_and_read_only(&mut self) {
        // `VirtualProtect` wants a place to store the previous protection, its value is not used.
        let mut previous_protection = MaybeUninit::<PAGE_PROTECTION_FLAGS>::uninit();
        let status = unsafe {
            VirtualProtect(
                self.ptr.cast(),
                self.size,
                PAGE_EXECUTE,
                previous_protection.as_mut_ptr(),
            )
        };
        assert_ne!(status, 0, "failed to change memory protection to executable");
    }
    /// views the allocation as a mutable byte slice.
    ///
    /// # Safety
    /// must be called only if the memory still has write permissions
    unsafe fn as_mut_slice(&mut self) -> &mut [u8] {
        let (start, len) = (self.ptr, self.size);
        unsafe { core::slice::from_raw_parts_mut(start, len) }
    }
    /// returns a const pointer to the start of the allocation
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr
    }
    /// returns a mutable pointer to the start of the allocation
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr
    }
    /// returns the size, in bytes, that the allocation was created with
    pub fn size(&self) -> usize {
        self.size
    }
}
impl Drop for Allocation {
    fn drop(&mut self) {
        // a size of 0 is passed together with `MEM_RELEASE`. the result is deliberately ignored.
        let _ = unsafe { VirtualFree(self.ptr.cast(), 0, MEM_RELEASE) };
    }
}
/// a hook which was installed on some function. owns the trampoline that was built for it.
pub struct Hook {
    trampoline: Allocation,
    fn_addr: usize,
    hook_to_addr: usize,
}
impl Hook {
    /// returns the address of the trampoline. calling it as a function simulates the behaviour of the
    /// original function without the hook.
    pub fn original_addr(&self) -> usize {
        self.trampoline.as_ptr() as usize
    }
    /// returns the trampoline reinterpreted as a function pointer of type `F`. calling it simulates the
    /// behaviour of the original function without the hook.
    /// the generic argument `F` should be the function pointer type of the original function (e.g `extern "C" fn(i32) -> i32`).
    ///
    /// # Safety
    ///
    /// the generic argument `F` must be a function pointer, and must have the correct signature of the original function.
    ///
    /// # Panics
    ///
    /// panics if the size or the alignment of `F` differs from that of `usize`.
    pub unsafe fn original<F: Copy>(&self) -> F {
        // a cheap sanity check: a function pointer has the size and alignment of a `usize`
        let looks_like_fn_ptr = core::mem::size_of::<F>() == core::mem::size_of::<usize>()
            && core::mem::align_of::<F>() == core::mem::align_of::<usize>();
        assert!(
            looks_like_fn_ptr,
            "provided function signature type {} is not a function pointer",
            core::any::type_name::<F>()
        );
        let trampoline_ptr: *mut u8 = self.trampoline.ptr;
        core::mem::transmute_copy::<*mut u8, F>(&trampoline_ptr)
    }
    /// returns the address of the function which was hooked
    pub fn fn_addr(&self) -> usize {
        self.fn_addr
    }
    /// returns the address which the hooked function now jumps to
    pub fn hook_to_addr(&self) -> usize {
        self.hook_to_addr
    }
    /// borrows the trampoline of this hook
    pub fn trampoline(&self) -> &Allocation {
        &self.trampoline
    }
    /// consumes the hook and returns its trampoline
    pub fn into_trampoline(self) -> Allocation {
        let Hook { trampoline, .. } = self;
        trampoline
    }
}

/// an error code reported by the windows api, displayed in hexadecimal.
#[derive(Debug, Error)]
#[error("winapi error code 0x{0:x}")]
pub struct WinapiError(pub u32);
impl WinapiError {
    /// captures the current value of `GetLastError`.
    pub fn last_error() -> Self {
        let code = unsafe { GetLastError() };
        Self(code)
    }
}

/// an error that occurs while hooking a function
#[derive(Debug, Error)]
pub enum Error {
    /// `GetModuleInformation` failed for the provided module. the source holds the winapi error code.
    #[error("failed to get module information")]
    FailedToGetModuleInformation(#[source] WinapiError),

    /// `LoadLibraryA` failed for the provided library name. the source holds the winapi error code.
    #[error("failed to load library")]
    FailedToLoadLibrary(#[source] WinapiError),

    /// `GetProcAddress` did not find the requested function in the loaded library.
    #[error("no function with the provided name exists in the specified library")]
    NoFunctionWithThatName,

    /// `gen_hook_info` from the `hooker` crate returned an error. converted automatically via `?`.
    #[error("failed to generate hook info")]
    FailedToGenHookInfo(
        #[source]
        #[from]
        hooker::HookError,
    ),
}

/// the result of hooking a function: a `core::result::Result` whose error type is fixed to [`Error`].
pub type Result<T> = core::result::Result<T, Error>;