Skip to main content

secexit_shim/
lib.rs

1use clap::Parser;
2use lazy_static::lazy_static;
3use libc::{addrinfo, c_char, c_int, sockaddr, socklen_t};
4use secexit_common::{SecurityPolicy, load_policy};
5use std::ffi::CStr;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Mutex;
8
9#[derive(Parser, Debug)]
10#[command(author, version, about = "secexit shim", long_about = None)]
11struct Args {
12    #[arg(
13        short,
14        long,
15        env = "SECEXIT_POLICY",
16        default_value = "~/.config/secexit/policy.json"
17    )]
18    policy: String,
19}
20
21// return TRUE if domain should be BLOCKED
22fn should_block_domain(hostname: &str, policy: &SecurityPolicy) -> bool {
23    if policy.lockdown_mode {
24        return true;
25    }
26    for domain in &policy.blocked_domains {
27        if hostname.contains(domain) {
28            return true;
29        }
30    }
31    false
32}
33
34// return TRUE if IP should be BLOCKED
35fn should_block_ip(ip_str: &str, policy: &SecurityPolicy) -> bool {
36    if policy.lockdown_mode {
37        return true;
38    }
39    for blocked in &policy.blocked_ips {
40        if ip_str == blocked {
41            return true;
42        }
43    }
44    false
45}
46
47lazy_static! {
48    static ref POLICY: Mutex<SecurityPolicy> = {
49        let final_path = Args::try_parse()
50            .map(|a| a.policy)
51            .unwrap_or_else(|_| "~/.config/secexit/policy.json".to_string());
52
53        match tokio::runtime::Runtime::new() {
54            Ok(rt) => {
55                let policy = rt.block_on(load_policy(&final_path));
56                Mutex::new(policy)
57            }
58            Err(e) => {
59                eprintln!(
60                    "[secexit] ERROR: Failed to create async runtime: {}. Defaulting to ALLOW.",
61                    e
62                );
63                Mutex::new(SecurityPolicy::default_allow())
64            }
65        }
66    };
67    static ref REAL_CONNECT: Mutex<Option<ConnectFn>> = Mutex::new(None);
68    static ref REAL_GETADDRINFO: Mutex<Option<GetAddrInfoFn>> = Mutex::new(None);
69}
70
71type ConnectFn = unsafe extern "C" fn(c_int, *const sockaddr, socklen_t) -> c_int;
72type GetAddrInfoFn = unsafe extern "C" fn(
73    *const c_char,
74    *const c_char,
75    *const addrinfo,
76    *mut *mut addrinfo,
77) -> c_int;
78
79unsafe fn get_real_connect() -> ConnectFn {
80    let mut real = REAL_CONNECT.lock().unwrap_or_else(|e| e.into_inner());
81    if let Some(f) = *real {
82        return f;
83    }
84    let sym = c"connect";
85    let ptr = unsafe { libc::dlsym(libc::RTLD_NEXT, sym.as_ptr()) };
86    let f: ConnectFn = unsafe { std::mem::transmute(ptr) };
87    *real = Some(f);
88    f
89}
90
/// Resolve the next `getaddrinfo` in symbol lookup order (the one this shim wraps) and
/// cache it.
///
/// If `dlsym` cannot find the symbol, the process is aborted with a message. A function
/// pointer can never be null, so transmuting a null result would be undefined behaviour.
///
/// # Safety
///
/// The symbol found via `dlsym(RTLD_NEXT, "getaddrinfo")` must have the C `getaddrinfo`
/// signature.
unsafe fn get_real_getaddrinfo() -> GetAddrInfoFn {
    let mut cached = REAL_GETADDRINFO.lock().unwrap_or_else(|e| e.into_inner());
    if let Some(f) = *cached {
        return f;
    }
    // SAFETY: `RTLD_NEXT` is a valid pseudo-handle and the name is a NUL-terminated literal.
    let ptr = unsafe { libc::dlsym(libc::RTLD_NEXT, c"getaddrinfo".as_ptr()) };
    if ptr.is_null() {
        eprintln!(
            "[secexit] FATAL: could not resolve the real `getaddrinfo` via dlsym(RTLD_NEXT)"
        );
        std::process::abort();
    }
    // SAFETY: `ptr` is non-null and refers to `getaddrinfo`, whose ABI matches `GetAddrInfoFn`.
    let f = unsafe { std::mem::transmute::<*mut libc::c_void, GetAddrInfoFn>(ptr) };
    *cached = Some(f);
    f
}
102
/// Hook for the standard libc `getaddrinfo` function.
///
/// If `node` names a host the policy blocks (or lockdown mode is on), the lookup is
/// refused with `EAI_FAIL` and a message is printed to stderr. Otherwise the call is
/// forwarded unchanged to the real `getaddrinfo`. The hostname is decoded lossily, so a
/// name that is not valid UTF-8 is still checked rather than skipping the policy entirely.
///
/// # Safety
///
/// This function is unsafe because it operates on raw C pointers.
/// The caller must ensure that:
/// * `node` and `service` are valid, null-terminated C strings (if provided).
/// * `hints` points to a valid `addrinfo` struct (if provided).
/// * `res` is a valid pointer to a pointer where the result will be stored.
/// * This function is intended to be called by the C runtime (libc).
#[unsafe(no_mangle)]
pub unsafe extern "C" fn getaddrinfo(
    node: *const c_char,
    service: *const c_char,
    hints: *const addrinfo,
    res: *mut *mut addrinfo,
) -> c_int {
    if !node.is_null() {
        // SAFETY: the caller guarantees a non-null `node` is a valid NUL-terminated string.
        let hostname = unsafe { CStr::from_ptr(node) }.to_string_lossy();
        let policy = POLICY.lock().unwrap_or_else(|e| e.into_inner());

        if should_block_domain(&hostname, &policy) {
            if policy.lockdown_mode {
                eprintln!("[secexit] LOCKDOWN: Blocking DNS lookup for {}", hostname);
            } else {
                eprintln!("[secexit] BLOCKED DOMAIN: {}", hostname);
            }
            return libc::EAI_FAIL;
        }
    }
    // SAFETY: all arguments are forwarded untouched to the real implementation.
    unsafe { get_real_getaddrinfo()(node, service, hints, res) }
}
136
137/// hook for standard libc `getaddrinfo` function.
138///
139/// # Safety
140///
141/// This function is unsafe because it operates on raw C pointers.
142/// The caller must ensure that:
143/// * `node` and `service` are valid, null-terminated C strings (if provided).
144/// * `hints` points to a valid `addrinfo` struct (if provided).
145/// * `res` is a valid pointer to a pointer where the result will be stored.
146/// * This function is intended to be called by the C runtime (libc).
147#[unsafe(no_mangle)]
148pub unsafe extern "C" fn connect(
149    sockfd: c_int,
150    addr: *const sockaddr,
151    addrlen: socklen_t,
152) -> c_int {
153    if let Some(sa) = unsafe { sockaddr_to_rust(addr, addrlen) }
154        && let IpAddr::V4(ipv4) = sa.ip()
155    {
156        let ip_str = ipv4.to_string();
157        let policy = POLICY.lock().unwrap_or_else(|e| e.into_inner());
158
159        if should_block_ip(&ip_str, &policy) {
160            eprintln!("[secexit] BLOCKED IP: {}", ip_str);
161            unsafe { *libc::__errno_location() = libc::EACCES };
162            return -1;
163        }
164    }
165    unsafe { get_real_connect()(sockfd, addr, addrlen) }
166}
167
168unsafe fn sockaddr_to_rust(addr: *const sockaddr, _len: socklen_t) -> Option<SocketAddr> {
169    if addr.is_null() {
170        return None;
171    }
172    let family = unsafe { (*addr).sa_family as i32 };
173    if family == libc::AF_INET {
174        let sin = unsafe { &*(addr as *const libc::sockaddr_in) };
175        let ip = std::net::Ipv4Addr::from(u32::from_be(sin.sin_addr.s_addr));
176        let port = u16::from_be(sin.sin_port);
177        return Some(SocketAddr::new(IpAddr::V4(ip), port));
178    }
179    None
180}
181
// Unit tests are kept in a separate `tests` module file and compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;