1use clap::Parser;
2use lazy_static::lazy_static;
3use libc::{addrinfo, c_char, c_int, sockaddr, socklen_t};
4use secexit_common::{SecurityPolicy, load_policy};
5use std::ffi::CStr;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Mutex;
8
9#[derive(Parser, Debug)]
10#[command(author, version, about = "secexit shim", long_about = None)]
11struct Args {
12 #[arg(
13 short,
14 long,
15 env = "SECEXIT_POLICY",
16 default_value = "~/.config/secexit/policy.json"
17 )]
18 policy: String,
19}
20
21fn should_block_domain(hostname: &str, policy: &SecurityPolicy) -> bool {
23 if policy.lockdown_mode {
24 return true;
25 }
26 for domain in &policy.blocked_domains {
27 if hostname.contains(domain) {
28 return true;
29 }
30 }
31 false
32}
33
34fn should_block_ip(ip_str: &str, policy: &SecurityPolicy) -> bool {
36 if policy.lockdown_mode {
37 return true;
38 }
39 for blocked in &policy.blocked_ips {
40 if ip_str == blocked {
41 return true;
42 }
43 }
44 false
45}
46
47lazy_static! {
48 static ref POLICY: Mutex<SecurityPolicy> = {
49 let final_path = Args::try_parse()
50 .map(|a| a.policy)
51 .unwrap_or_else(|_| "~/.config/secexit/policy.json".to_string());
52
53 match tokio::runtime::Runtime::new() {
54 Ok(rt) => {
55 let policy = rt.block_on(load_policy(&final_path));
56 Mutex::new(policy)
57 }
58 Err(e) => {
59 eprintln!(
60 "[secexit] ERROR: Failed to create async runtime: {}. Defaulting to ALLOW.",
61 e
62 );
63 Mutex::new(SecurityPolicy::default_allow())
64 }
65 }
66 };
67 static ref REAL_CONNECT: Mutex<Option<ConnectFn>> = Mutex::new(None);
68 static ref REAL_GETADDRINFO: Mutex<Option<GetAddrInfoFn>> = Mutex::new(None);
69}
70
71type ConnectFn = unsafe extern "C" fn(c_int, *const sockaddr, socklen_t) -> c_int;
72type GetAddrInfoFn = unsafe extern "C" fn(
73 *const c_char,
74 *const c_char,
75 *const addrinfo,
76 *mut *mut addrinfo,
77) -> c_int;
78
79unsafe fn get_real_connect() -> ConnectFn {
80 let mut real = REAL_CONNECT.lock().unwrap_or_else(|e| e.into_inner());
81 if let Some(f) = *real {
82 return f;
83 }
84 let sym = c"connect";
85 let ptr = unsafe { libc::dlsym(libc::RTLD_NEXT, sym.as_ptr()) };
86 let f: ConnectFn = unsafe { std::mem::transmute(ptr) };
87 *real = Some(f);
88 f
89}
90
/// Return the next `getaddrinfo` definition after this shim (normally libc's).
///
/// The symbol is resolved with `dlsym(RTLD_NEXT, "getaddrinfo")` on first use and cached in
/// `REAL_GETADDRINFO`. A poisoned cache lock is recovered rather than propagated.
///
/// Aborts the process if the symbol cannot be found. Transmuting a null pointer into a
/// function pointer is undefined behaviour, and there is no real resolver to forward to.
///
/// # Safety
/// The returned pointer must only be called with `getaddrinfo(3)`'s argument contract.
unsafe fn get_real_getaddrinfo() -> GetAddrInfoFn {
    let mut cached = REAL_GETADDRINFO.lock().unwrap_or_else(|e| e.into_inner());
    if let Some(f) = *cached {
        return f;
    }
    // SAFETY: `RTLD_NEXT` lookup with a NUL-terminated symbol name.
    let ptr = unsafe { libc::dlsym(libc::RTLD_NEXT, c"getaddrinfo".as_ptr()) };
    if ptr.is_null() {
        eprintln!("[secexit] FATAL: dlsym(RTLD_NEXT, \"getaddrinfo\") returned NULL");
        std::process::abort();
    }
    // SAFETY: `ptr` is non-null and names the C `getaddrinfo` symbol, whose signature is
    // `GetAddrInfoFn`.
    let f = unsafe { std::mem::transmute::<*mut libc::c_void, GetAddrInfoFn>(ptr) };
    *cached = Some(f);
    f
}
102
/// `getaddrinfo(3)` hook: refuses lookups the policy blocks, otherwise forwards to the real one.
///
/// Blocked lookups return `EAI_FAIL` and log to stderr. A null `node` is never checked and is
/// forwarded unchanged.
///
/// The hostname is decoded lossily rather than with strict UTF-8. With strict decoding, a name
/// containing invalid bytes (e.g. `\xff.evil.com`, which wildcard DNS may still answer) would
/// skip the policy check entirely, even in lockdown mode.
///
/// # Safety
/// Same contract as libc's `getaddrinfo`: a non-null `node` must point to a NUL-terminated string.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn getaddrinfo(
    node: *const c_char,
    service: *const c_char,
    hints: *const addrinfo,
    res: *mut *mut addrinfo,
) -> c_int {
    if !node.is_null() {
        // SAFETY: non-null `node` is a NUL-terminated C string per getaddrinfo's contract.
        let hostname = unsafe { CStr::from_ptr(node) }.to_string_lossy();
        // The guard is scoped to this block, so it is released before forwarding below.
        let policy = POLICY.lock().unwrap_or_else(|e| e.into_inner());

        if should_block_domain(&hostname, &policy) {
            if policy.lockdown_mode {
                eprintln!("[secexit] LOCKDOWN: Blocking DNS lookup for {}", hostname);
            } else {
                eprintln!("[secexit] BLOCKED DOMAIN: {}", hostname);
            }
            return libc::EAI_FAIL;
        }
    }
    // SAFETY: arguments are forwarded untouched to the real getaddrinfo.
    unsafe { get_real_getaddrinfo()(node, service, hints, res) }
}
136
137#[unsafe(no_mangle)]
148pub unsafe extern "C" fn connect(
149 sockfd: c_int,
150 addr: *const sockaddr,
151 addrlen: socklen_t,
152) -> c_int {
153 if let Some(sa) = unsafe { sockaddr_to_rust(addr, addrlen) }
154 && let IpAddr::V4(ipv4) = sa.ip()
155 {
156 let ip_str = ipv4.to_string();
157 let policy = POLICY.lock().unwrap_or_else(|e| e.into_inner());
158
159 if should_block_ip(&ip_str, &policy) {
160 eprintln!("[secexit] BLOCKED IP: {}", ip_str);
161 unsafe { *libc::__errno_location() = libc::EACCES };
162 return -1;
163 }
164 }
165 unsafe { get_real_connect()(sockfd, addr, addrlen) }
166}
167
168unsafe fn sockaddr_to_rust(addr: *const sockaddr, _len: socklen_t) -> Option<SocketAddr> {
169 if addr.is_null() {
170 return None;
171 }
172 let family = unsafe { (*addr).sa_family as i32 };
173 if family == libc::AF_INET {
174 let sin = unsafe { &*(addr as *const libc::sockaddr_in) };
175 let ip = std::net::Ipv4Addr::from(u32::from_be(sin.sin_addr.s_addr));
176 let port = u16::from_be(sin.sin_port);
177 return Some(SocketAddr::new(IpAddr::V4(ip), port));
178 }
179 None
180}
181
182#[cfg(test)]
183mod tests;