1use core_foundation::{
6 base::{TCFType, ToVoid},
7 runloop::CFRunLoop,
8 string::{CFString, CFStringRef},
9};
10use system_configuration_sys::{
11 libc,
12 network_reachability::{
13 kSCNetworkReachabilityFlagsConnectionOnDemand,
14 kSCNetworkReachabilityFlagsConnectionOnTraffic,
15 kSCNetworkReachabilityFlagsConnectionRequired,
16 kSCNetworkReachabilityFlagsInterventionRequired, kSCNetworkReachabilityFlagsIsDirect,
17 kSCNetworkReachabilityFlagsIsLocalAddress, kSCNetworkReachabilityFlagsIsWWAN,
18 kSCNetworkReachabilityFlagsReachable, kSCNetworkReachabilityFlagsTransientConnection,
19 SCNetworkReachabilityContext, SCNetworkReachabilityCreateWithAddress,
20 SCNetworkReachabilityCreateWithAddressPair, SCNetworkReachabilityCreateWithName,
21 SCNetworkReachabilityFlags, SCNetworkReachabilityGetFlags, SCNetworkReachabilityGetTypeID,
22 SCNetworkReachabilityRef, SCNetworkReachabilityScheduleWithRunLoop,
23 SCNetworkReachabilitySetCallback, SCNetworkReachabilityUnscheduleFromRunLoop,
24 },
25};
26
27use std::{
28 error::Error,
29 ffi::{c_void, CStr},
30 fmt::{self, Display},
31 net::SocketAddr,
32 ptr,
33 sync::Arc,
34};
35
/// Failure to determine the reachability of a target.
///
/// Returned by [`SCNetworkReachability::reachability`].
#[derive(Debug)]
pub enum ReachabilityError {
    /// The `SCNetworkReachabilityGetFlags` call reported failure.
    FailedToDetermineReachability,
    /// `SCNetworkReachabilityGetFlags` succeeded, but the returned bits could not be converted
    /// into [`ReachabilityFlags`]. The raw value is carried along for diagnostics.
    UnrecognizedFlags(u32),
}

impl Display for ReachabilityError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::FailedToDetermineReachability => f.write_str("Failed to determine reachability"),
            // Flags are a bitmask, so hexadecimal is the most readable rendering.
            Self::UnrecognizedFlags(flags) => {
                write!(f, "Unrecognized reachability flags: {:#x}", flags)
            }
        }
    }
}

// Matches the other error types in this file, which all implement `Error`.
impl Error for ReachabilityError {}
44
/// Error produced when a reachability target could not be scheduled on a run loop.
///
/// The private unit field keeps the type from being constructed outside this module.
#[derive(Debug)]
pub struct SchedulingError(());

impl Error for SchedulingError {}

impl Display for SchedulingError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Failed to schedule a reachability callback on a runloop")
    }
}
56
/// Error produced when a reachability target could not be removed from a run loop.
///
/// The private unit field keeps the type from being constructed outside this module.
#[derive(Debug)]
pub struct UnschedulingError(());

impl Error for UnschedulingError {}

impl Display for UnschedulingError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Failed to unschedule a reachability callback on a runloop")
    }
}
71
/// Error produced when installing a reachability callback fails.
#[derive(Debug)]
pub struct SetCallbackError {}

impl Error for SetCallbackError {}

impl Display for SetCallbackError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Failed to set a callback for reachability")
    }
}
83
bitflags::bitflags! {
    /// Typed view of the raw `SCNetworkReachabilityFlags` bitmask.
    ///
    /// Each constant takes its value from the `kSCNetworkReachabilityFlags*` constant of the
    /// matching name exported by `system_configuration_sys`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct ReachabilityFlags: u32 {
        /// Mirrors `kSCNetworkReachabilityFlagsTransientConnection`.
        const TRANSIENT_CONNECTION = kSCNetworkReachabilityFlagsTransientConnection;
        /// Mirrors `kSCNetworkReachabilityFlagsReachable`.
        const REACHABLE = kSCNetworkReachabilityFlagsReachable;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionRequired`.
        const CONNECTION_REQUIRED = kSCNetworkReachabilityFlagsConnectionRequired;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionOnTraffic`.
        const CONNECTION_ON_TRAFFIC = kSCNetworkReachabilityFlagsConnectionOnTraffic;
        /// Mirrors `kSCNetworkReachabilityFlagsInterventionRequired`.
        const INTERVENTION_REQUIRED = kSCNetworkReachabilityFlagsInterventionRequired;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionOnDemand`.
        const CONNECTION_ON_DEMAND = kSCNetworkReachabilityFlagsConnectionOnDemand;
        /// Mirrors `kSCNetworkReachabilityFlagsIsLocalAddress`.
        const IS_LOCAL_ADDRESS = kSCNetworkReachabilityFlagsIsLocalAddress;
        /// Mirrors `kSCNetworkReachabilityFlagsIsDirect`.
        const IS_DIRECT = kSCNetworkReachabilityFlagsIsDirect;
        /// Mirrors `kSCNetworkReachabilityFlagsIsWWAN`.
        const IS_WWAN = kSCNetworkReachabilityFlagsIsWWAN;
    }
}
121
// Declares the `SCNetworkReachability` wrapper type around the raw
// `SCNetworkReachabilityRef`. The rest of this file reaches the raw reference
// through the wrapper's `.0` field.
core_foundation::declare_TCFType!(
    SCNetworkReachability,
    SCNetworkReachabilityRef
);

// Implements `TCFType` for the wrapper, using `SCNetworkReachabilityGetTypeID`
// as the Core Foundation type-ID function. This is what provides
// `wrap_under_create_rule`, used by the constructors below.
core_foundation::impl_TCFType!(
    SCNetworkReachability,
    SCNetworkReachabilityRef,
    SCNetworkReachabilityGetTypeID
);
137
138impl SCNetworkReachability {
139 pub fn from_addr_pair(local: SocketAddr, remote: SocketAddr) -> SCNetworkReachability {
145 let ptr = unsafe {
146 SCNetworkReachabilityCreateWithAddressPair(
147 std::ptr::null(),
148 &*to_c_sockaddr(local),
149 &*to_c_sockaddr(remote),
150 )
151 };
152
153 unsafe { Self::wrap_under_create_rule(ptr) }
154 }
155
156 pub fn from_host(host: &CStr) -> Option<Self> {
162 let ptr = unsafe { SCNetworkReachabilityCreateWithName(ptr::null(), host.as_ptr()) };
163 if ptr.is_null() {
164 None
165 } else {
166 unsafe { Some(Self::wrap_under_create_rule(ptr)) }
167 }
168 }
169
170 pub fn reachability(&self) -> Result<ReachabilityFlags, ReachabilityError> {
176 let mut raw_flags = 0u32;
177 if unsafe { SCNetworkReachabilityGetFlags(self.0, &mut raw_flags) } == 0u8 {
178 return Err(ReachabilityError::FailedToDetermineReachability);
179 }
180
181 ReachabilityFlags::from_bits(raw_flags)
182 .ok_or(ReachabilityError::UnrecognizedFlags(raw_flags))
183 }
184
185 pub unsafe fn schedule_with_runloop(
196 &self,
197 run_loop: &CFRunLoop,
198 run_loop_mode: CFStringRef,
199 ) -> Result<(), SchedulingError> {
200 if SCNetworkReachabilityScheduleWithRunLoop(
201 self.0,
202 run_loop.to_void() as *mut _,
203 run_loop_mode,
204 ) == 0u8
205 {
206 Err(SchedulingError(()))
207 } else {
208 Ok(())
209 }
210 }
211
212 pub unsafe fn unschedule_from_runloop(
223 &self,
224 run_loop: &CFRunLoop,
225 run_loop_mode: CFStringRef,
226 ) -> Result<(), UnschedulingError> {
227 if SCNetworkReachabilityUnscheduleFromRunLoop(
228 self.0,
229 run_loop.to_void() as *mut _,
230 run_loop_mode,
231 ) == 0u8
232 {
233 Err(UnschedulingError(()))
234 } else {
235 Ok(())
236 }
237 }
238
239 pub fn set_callback<F: Fn(ReachabilityFlags) + Sync + Send>(
247 &mut self,
248 callback: F,
249 ) -> Result<(), SetCallbackError> {
250 let callback = Arc::new(NetworkReachabilityCallbackContext::new(
251 self.clone(),
252 callback,
253 ));
254
255 let mut callback_context = SCNetworkReachabilityContext {
256 version: 0,
257 info: Arc::into_raw(callback) as *mut _,
258 retain: Some(NetworkReachabilityCallbackContext::<F>::retain_context),
259 release: Some(NetworkReachabilityCallbackContext::<F>::release_context),
260 copyDescription: Some(NetworkReachabilityCallbackContext::<F>::copy_ctx_description),
261 };
262
263 let result = unsafe {
264 SCNetworkReachabilitySetCallback(
265 self.0,
266 Some(NetworkReachabilityCallbackContext::<F>::callback),
267 &mut callback_context,
268 )
269 };
270
271 unsafe { Arc::decrement_strong_count(callback_context.info) };
281
282 if result == 0u8 {
283 Err(SetCallbackError {})
284 } else {
285 Ok(())
286 }
287 }
288}
289
290impl From<SocketAddr> for SCNetworkReachability {
291 fn from(addr: SocketAddr) -> Self {
292 unsafe {
293 let ptr =
294 SCNetworkReachabilityCreateWithAddress(std::ptr::null(), &*to_c_sockaddr(addr));
295 SCNetworkReachability::wrap_under_create_rule(ptr)
296 }
297 }
298}
299
300struct NetworkReachabilityCallbackContext<T: Fn(ReachabilityFlags) + Sync + Send> {
301 _host: SCNetworkReachability,
302 callback: T,
303}
304
305impl<T: Fn(ReachabilityFlags) + Sync + Send> NetworkReachabilityCallbackContext<T> {
306 fn new(host: SCNetworkReachability, callback: T) -> Self {
307 Self {
308 _host: host,
309 callback,
310 }
311 }
312
313 extern "C" fn callback(
314 _target: SCNetworkReachabilityRef,
315 flags: SCNetworkReachabilityFlags,
316 context: *mut c_void,
317 ) {
318 let context: &mut Self = unsafe { &mut (*(context as *mut _)) };
319 (context.callback)(ReachabilityFlags::from_bits_retain(flags));
320 }
321
322 extern "C" fn copy_ctx_description(_ctx: *const c_void) -> CFStringRef {
323 let description = CFString::from_static_string("NetworkRechability's callback context");
324 let description_ref = description.as_concrete_TypeRef();
325 std::mem::forget(description);
326 description_ref
327 }
328
329 extern "C" fn release_context(ctx: *const c_void) {
330 unsafe {
331 Arc::decrement_strong_count(ctx as *mut Self);
332 }
333 }
334
335 extern "C" fn retain_context(ctx_ptr: *const c_void) -> *const c_void {
336 unsafe {
337 Arc::increment_strong_count(ctx_ptr as *mut Self);
338 }
339 ctx_ptr
340 }
341}
342
343fn to_c_sockaddr(addr: SocketAddr) -> Box<libc::sockaddr> {
346 let ptr = match addr {
347 SocketAddr::V4(addr) => Box::into_raw(Box::new(libc::sockaddr_in {
351 sin_len: std::mem::size_of::<libc::sockaddr_in>() as u8,
352 sin_family: libc::AF_INET as libc::sa_family_t,
353 sin_port: addr.port().to_be(),
354 sin_addr: {
355 libc::in_addr {
359 s_addr: u32::from_ne_bytes(addr.ip().octets()),
360 }
361 },
362 sin_zero: Default::default(),
363 })) as *mut c_void,
364 SocketAddr::V6(addr) => Box::into_raw(Box::new(libc::sockaddr_in6 {
368 sin6_len: std::mem::size_of::<libc::sockaddr_in6>() as u8,
369 sin6_family: libc::AF_INET6 as libc::sa_family_t,
370 sin6_port: addr.port().to_be(),
371 sin6_flowinfo: addr.flowinfo(),
372 sin6_addr: libc::in6_addr {
373 s6_addr: addr.ip().octets(),
374 },
375 sin6_scope_id: addr.scope_id(),
376 })) as *mut c_void,
377 };
378
379 unsafe { Box::from_raw(ptr as *mut _) }
380}
381
#[cfg(test)]
mod test {
    use super::*;

    use core_foundation::runloop::{kCFRunLoopCommonModes, CFRunLoop};
    use std::{
        ffi::CString,
        net::{Ipv4Addr, Ipv6Addr},
    };

    /// Schedules `reachability` on the current run loop (common modes) and immediately
    /// unschedules it again, panicking if either step fails.
    fn schedule_then_unschedule(reachability: &SCNetworkReachability) {
        unsafe {
            reachability
                .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                .unwrap();
            reachability
                .unschedule_from_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                .unwrap();
        }
    }

    /// A target built from a single unspecified address accepts a callback and can be
    /// scheduled and unscheduled.
    #[test]
    fn test_network_reachability_from_addr() {
        let sockaddrs: Vec<SocketAddr> = ["0.0.0.0:0", "[::0]:0"]
            .iter()
            .map(|text| text.parse().unwrap())
            .collect();

        for addr in sockaddrs {
            let mut reachability = SCNetworkReachability::from(addr);
            assert!(
                !reachability.0.is_null(),
                "Failed to construct a SCNetworkReachability struct with {}",
                addr
            );
            reachability.set_callback(|_| {}).unwrap();
            schedule_then_unschedule(&reachability);
        }
    }

    /// Same round trip as above, for every IPv4/IPv6 combination of an address pair.
    #[test]
    fn test_sockaddr_pair_reachability() {
        let combinations = [
            ("0.0.0.0:0", "[::0]:0"),
            ("[::0]:0", "0.0.0.0:0"),
            ("[::0]:0", "[::0]:0"),
            ("0.0.0.0:0", "0.0.0.0:0"),
        ];

        for &(local, remote) in combinations.iter() {
            let local: SocketAddr = local.parse().unwrap();
            let remote: SocketAddr = remote.parse().unwrap();

            let mut reachability = SCNetworkReachability::from_addr_pair(local, remote);
            assert!(
                !reachability.0.is_null(),
                "Failed to construct a SCNetworkReachability struct with address pair {} - {}",
                local,
                remote
            );
            reachability.set_callback(|_| {}).unwrap();
            schedule_then_unschedule(&reachability);
        }
    }

    /// The `REACHABLE` flag must agree with whether a TCP connection to the remote succeeds.
    #[test]
    fn test_sockaddr_local_to_dns_google_pair_reachability() {
        let sockaddrs = [
            "[2001:4860:4860::8844]:443".parse::<SocketAddr>().unwrap(),
            "8.8.4.4:443".parse().unwrap(),
        ];
        for remote_addr in sockaddrs {
            match std::net::TcpStream::connect(remote_addr) {
                Ok(tcp) => {
                    // Connected: use the real endpoints of the stream, which stays open while
                    // the flags are queried.
                    let local = tcp.local_addr().unwrap();
                    let remote = tcp.peer_addr().unwrap();
                    let flags = SCNetworkReachability::from_addr_pair(local, remote)
                        .reachability()
                        .unwrap();
                    assert!(flags.contains(ReachabilityFlags::REACHABLE));
                }
                Err(_) => {
                    // Not connected: pair the remote with an unspecified local address of the
                    // same family.
                    let unspecified_ip = if remote_addr.is_ipv4() {
                        Ipv4Addr::UNSPECIFIED.into()
                    } else {
                        Ipv6Addr::UNSPECIFIED.into()
                    };
                    let local_addr = SocketAddr::new(unspecified_ip, 0);
                    let flags = SCNetworkReachability::from_addr_pair(local_addr, remote_addr)
                        .reachability()
                        .unwrap();
                    assert!(!flags.contains(ReachabilityFlags::REACHABLE));
                }
            }
        }
    }

    /// Host names yield a usable target; the empty string must not.
    #[test]
    fn test_reachability_ref_from_host() {
        let valid_inputs = vec!["example.com", "host-in-local-network", "en0"];

        let get_cstring = |input: &str| CString::new(input).unwrap();

        for input in valid_inputs.into_iter().map(get_cstring) {
            let mut reachability =
                SCNetworkReachability::from_host(&input).unwrap_or_else(|| {
                    panic!(
                        "Failed to construct a SCNetworkReachability from {}",
                        input.to_string_lossy(),
                    )
                });
            reachability.set_callback(|_| {}).unwrap();
            schedule_then_unschedule(&reachability);
        }

        assert!(
            SCNetworkReachability::from_host(&get_cstring("")).is_none(),
            "Constructed valid SCNetworkReachability from empty string"
        );
    }

    // Test-only: lets the target be sent across the channel in the test below.
    unsafe impl Send for SCNetworkReachability {}

    /// Setting a callback must keep working before scheduling, after scheduling, and from a
    /// different thread while the original thread runs its run loop.
    #[test]
    fn assert_infallibility_of_setting_a_callback() {
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            let unspecified: SocketAddr = "0.0.0.0:0".parse().unwrap();
            let mut reachability = SCNetworkReachability::from(unspecified);
            reachability.set_callback(|_| {}).unwrap();
            unsafe {
                reachability
                    .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
            }
            reachability.set_callback(|_| {}).unwrap();
            let _ = tx.send(reachability);
            CFRunLoop::run_current();
        });
        let mut reachability = rx.recv().unwrap();
        std::thread::sleep(std::time::Duration::from_secs(1));
        reachability.set_callback(|_| {}).unwrap();
    }
}