1use core_foundation::{
6 base::{TCFType, ToVoid},
7 runloop::CFRunLoop,
8 string::{CFString, CFStringRef},
9};
10use system_configuration_sys::{
11 libc,
12 network_reachability::{
13 kSCNetworkReachabilityFlagsConnectionOnDemand,
14 kSCNetworkReachabilityFlagsConnectionOnTraffic,
15 kSCNetworkReachabilityFlagsConnectionRequired,
16 kSCNetworkReachabilityFlagsInterventionRequired, kSCNetworkReachabilityFlagsIsDirect,
17 kSCNetworkReachabilityFlagsIsLocalAddress, kSCNetworkReachabilityFlagsIsWWAN,
18 kSCNetworkReachabilityFlagsReachable, kSCNetworkReachabilityFlagsTransientConnection,
19 SCNetworkReachabilityContext, SCNetworkReachabilityCreateWithAddress,
20 SCNetworkReachabilityCreateWithAddressPair, SCNetworkReachabilityCreateWithName,
21 SCNetworkReachabilityFlags, SCNetworkReachabilityGetFlags, SCNetworkReachabilityGetTypeID,
22 SCNetworkReachabilityRef, SCNetworkReachabilityScheduleWithRunLoop,
23 SCNetworkReachabilitySetCallback, SCNetworkReachabilityUnscheduleFromRunLoop,
24 },
25};
26
27use std::{
28 error::Error,
29 ffi::{c_void, CStr},
30 fmt::{self, Display},
31 net::SocketAddr,
32 ptr,
33 sync::Arc,
34};
35
/// Error produced when the reachability of a target cannot be reported.
#[derive(Debug)]
pub enum ReachabilityError {
    /// The `SCNetworkReachabilityGetFlags` call signalled failure.
    FailedToDetermineReachability,
    /// `SCNetworkReachabilityGetFlags` returned a mask with bits that are not defined in
    /// [`ReachabilityFlags`]; the raw mask is carried in the variant.
    UnrecognizedFlags(u32),
}

impl Display for ReachabilityError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ReachabilityError::FailedToDetermineReachability => {
                f.write_str("Failed to determine reachability")
            }
            ReachabilityError::UnrecognizedFlags(raw) => {
                write!(f, "Unrecognized reachability flags: {}", raw)
            }
        }
    }
}

impl Error for ReachabilityError {}
57
/// Error returned when `SCNetworkReachabilityScheduleWithRunLoop` reports failure.
///
/// The private unit field means only this module can construct the value.
#[derive(Debug)]
pub struct SchedulingError(());

impl Display for SchedulingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Failed to schedule a reachability callback on a runloop")
    }
}

impl Error for SchedulingError {}
69
/// Error returned when `SCNetworkReachabilityUnscheduleFromRunLoop` reports failure.
///
/// The private unit field means only this module can construct the value.
#[derive(Debug)]
pub struct UnschedulingError(());

impl Display for UnschedulingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Failed to unschedule a reachability callback on a runloop")
    }
}

impl Error for UnschedulingError {}
84
/// Error returned when `SCNetworkReachabilitySetCallback` reports failure.
#[derive(Debug)]
pub struct SetCallbackError {}

impl Display for SetCallbackError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Failed to set a callback for reachability")
    }
}

impl Error for SetCallbackError {}
96
bitflags::bitflags! {
    /// Typed wrapper around the raw `SCNetworkReachabilityFlags` bit mask.
    ///
    /// Every constant is defined as the matching `kSCNetworkReachabilityFlags*` value
    /// imported from `system_configuration_sys`; Apple's `SCNetworkReachabilityFlags`
    /// documentation describes what each bit means.
    ///
    /// Note the two conversions used in this file: `SCNetworkReachability::reachability`
    /// uses `from_bits`, which rejects masks containing undefined bits, while the run-loop
    /// callback trampoline uses `from_bits_retain`, which keeps them.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct ReachabilityFlags: u32 {
        /// Mirrors `kSCNetworkReachabilityFlagsTransientConnection`.
        const TRANSIENT_CONNECTION = kSCNetworkReachabilityFlagsTransientConnection;
        /// Mirrors `kSCNetworkReachabilityFlagsReachable`.
        const REACHABLE = kSCNetworkReachabilityFlagsReachable;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionRequired`.
        const CONNECTION_REQUIRED = kSCNetworkReachabilityFlagsConnectionRequired;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionOnTraffic`.
        const CONNECTION_ON_TRAFFIC = kSCNetworkReachabilityFlagsConnectionOnTraffic;
        /// Mirrors `kSCNetworkReachabilityFlagsInterventionRequired`.
        const INTERVENTION_REQUIRED = kSCNetworkReachabilityFlagsInterventionRequired;
        /// Mirrors `kSCNetworkReachabilityFlagsConnectionOnDemand`.
        const CONNECTION_ON_DEMAND = kSCNetworkReachabilityFlagsConnectionOnDemand;
        /// Mirrors `kSCNetworkReachabilityFlagsIsLocalAddress`.
        const IS_LOCAL_ADDRESS = kSCNetworkReachabilityFlagsIsLocalAddress;
        /// Mirrors `kSCNetworkReachabilityFlagsIsDirect`.
        const IS_DIRECT = kSCNetworkReachabilityFlagsIsDirect;
        /// Mirrors `kSCNetworkReachabilityFlagsIsWWAN`.
        const IS_WWAN = kSCNetworkReachabilityFlagsIsWWAN;
    }
}
134
core_foundation::declare_TCFType!(
    /// Core Foundation wrapper for a SystemConfiguration `SCNetworkReachabilityRef`: a
    /// network address, address pair, or host name whose reachability can be queried.
    ///
    /// Construct one with [`SCNetworkReachability::from_addr_pair`],
    /// [`SCNetworkReachability::from_host`] or its `From<SocketAddr>` implementation.
    SCNetworkReachability,
    SCNetworkReachabilityRef
);
144
// Implements Core Foundation's `TCFType` trait for `SCNetworkReachability`, using
// `SCNetworkReachabilityGetTypeID` as its type-ID function. The `wrap_under_create_rule`
// constructor called by the constructors in this file is a `TCFType` method.
core_foundation::impl_TCFType!(
    SCNetworkReachability,
    SCNetworkReachabilityRef,
    SCNetworkReachabilityGetTypeID
);
150
151impl SCNetworkReachability {
152 pub fn from_addr_pair(local: SocketAddr, remote: SocketAddr) -> SCNetworkReachability {
158 let ptr = unsafe {
159 SCNetworkReachabilityCreateWithAddressPair(
160 std::ptr::null(),
161 &*to_c_sockaddr(local),
162 &*to_c_sockaddr(remote),
163 )
164 };
165
166 unsafe { Self::wrap_under_create_rule(ptr) }
167 }
168
169 pub fn from_host(host: &CStr) -> Option<Self> {
175 let ptr = unsafe { SCNetworkReachabilityCreateWithName(ptr::null(), host.as_ptr()) };
176 if ptr.is_null() {
177 None
178 } else {
179 unsafe { Some(Self::wrap_under_create_rule(ptr)) }
180 }
181 }
182
183 pub fn reachability(&self) -> Result<ReachabilityFlags, ReachabilityError> {
189 let mut raw_flags = 0u32;
190 if unsafe { SCNetworkReachabilityGetFlags(self.0, &mut raw_flags) } == 0u8 {
191 return Err(ReachabilityError::FailedToDetermineReachability);
192 }
193
194 ReachabilityFlags::from_bits(raw_flags)
195 .ok_or(ReachabilityError::UnrecognizedFlags(raw_flags))
196 }
197
198 pub unsafe fn schedule_with_runloop(
209 &self,
210 run_loop: &CFRunLoop,
211 run_loop_mode: CFStringRef,
212 ) -> Result<(), SchedulingError> {
213 if SCNetworkReachabilityScheduleWithRunLoop(
214 self.0,
215 run_loop.to_void() as *mut _,
216 run_loop_mode,
217 ) == 0u8
218 {
219 Err(SchedulingError(()))
220 } else {
221 Ok(())
222 }
223 }
224
225 pub unsafe fn unschedule_from_runloop(
236 &self,
237 run_loop: &CFRunLoop,
238 run_loop_mode: CFStringRef,
239 ) -> Result<(), UnschedulingError> {
240 if SCNetworkReachabilityUnscheduleFromRunLoop(
241 self.0,
242 run_loop.to_void() as *mut _,
243 run_loop_mode,
244 ) == 0u8
245 {
246 Err(UnschedulingError(()))
247 } else {
248 Ok(())
249 }
250 }
251
252 pub fn set_callback<F: Fn(ReachabilityFlags) + Sync + Send>(
260 &mut self,
261 callback: F,
262 ) -> Result<(), SetCallbackError> {
263 let callback = Arc::new(NetworkReachabilityCallbackContext::new(
264 self.clone(),
265 callback,
266 ));
267
268 let mut callback_context = SCNetworkReachabilityContext {
269 version: 0,
270 info: Arc::into_raw(callback) as *mut _,
271 retain: Some(NetworkReachabilityCallbackContext::<F>::retain_context),
272 release: Some(NetworkReachabilityCallbackContext::<F>::release_context),
273 copyDescription: Some(NetworkReachabilityCallbackContext::<F>::copy_ctx_description),
274 };
275
276 let result = unsafe {
277 SCNetworkReachabilitySetCallback(
278 self.0,
279 Some(NetworkReachabilityCallbackContext::<F>::callback),
280 &mut callback_context,
281 )
282 };
283
284 unsafe { Arc::decrement_strong_count(callback_context.info) };
294
295 if result == 0u8 {
296 Err(SetCallbackError {})
297 } else {
298 Ok(())
299 }
300 }
301}
302
303impl From<SocketAddr> for SCNetworkReachability {
304 fn from(addr: SocketAddr) -> Self {
305 unsafe {
306 let ptr =
307 SCNetworkReachabilityCreateWithAddress(std::ptr::null(), &*to_c_sockaddr(addr));
308 SCNetworkReachability::wrap_under_create_rule(ptr)
309 }
310 }
311}
312
313struct NetworkReachabilityCallbackContext<T: Fn(ReachabilityFlags) + Sync + Send> {
314 _host: SCNetworkReachability,
315 callback: T,
316}
317
318impl<T: Fn(ReachabilityFlags) + Sync + Send> NetworkReachabilityCallbackContext<T> {
319 fn new(host: SCNetworkReachability, callback: T) -> Self {
320 Self {
321 _host: host,
322 callback,
323 }
324 }
325
326 extern "C" fn callback(
327 _target: SCNetworkReachabilityRef,
328 flags: SCNetworkReachabilityFlags,
329 context: *mut c_void,
330 ) {
331 let context: &mut Self = unsafe { &mut (*(context as *mut _)) };
332 (context.callback)(ReachabilityFlags::from_bits_retain(flags));
333 }
334
335 extern "C" fn copy_ctx_description(_ctx: *const c_void) -> CFStringRef {
336 let description = CFString::from_static_string("NetworkRechability's callback context");
337 let description_ref = description.as_concrete_TypeRef();
338 std::mem::forget(description);
339 description_ref
340 }
341
342 extern "C" fn release_context(ctx: *const c_void) {
343 unsafe {
344 Arc::decrement_strong_count(ctx as *mut Self);
345 }
346 }
347
348 extern "C" fn retain_context(ctx_ptr: *const c_void) -> *const c_void {
349 unsafe {
350 Arc::increment_strong_count(ctx_ptr as *mut Self);
351 }
352 ctx_ptr
353 }
354}
355
356fn to_c_sockaddr(addr: SocketAddr) -> Box<libc::sockaddr> {
359 let ptr = match addr {
360 SocketAddr::V4(addr) => Box::into_raw(Box::new(libc::sockaddr_in {
364 sin_len: std::mem::size_of::<libc::sockaddr_in>() as u8,
365 sin_family: libc::AF_INET as libc::sa_family_t,
366 sin_port: addr.port().to_be(),
367 sin_addr: {
368 libc::in_addr {
372 s_addr: u32::from_ne_bytes(addr.ip().octets()),
373 }
374 },
375 sin_zero: Default::default(),
376 })) as *mut c_void,
377 SocketAddr::V6(addr) => Box::into_raw(Box::new(libc::sockaddr_in6 {
381 sin6_len: std::mem::size_of::<libc::sockaddr_in6>() as u8,
382 sin6_family: libc::AF_INET6 as libc::sa_family_t,
383 sin6_port: addr.port().to_be(),
384 sin6_flowinfo: addr.flowinfo(),
385 sin6_addr: libc::in6_addr {
386 s6_addr: addr.ip().octets(),
387 },
388 sin6_scope_id: addr.scope_id(),
389 })) as *mut c_void,
390 };
391
392 unsafe { Box::from_raw(ptr as *mut _) }
393}
394
#[cfg(test)]
mod test {
    // These tests call into the real SystemConfiguration framework and Core Foundation run
    // loops; `test_sockaddr_local_to_dns_google_pair_reachability` also attempts outbound
    // TCP connections.
    use super::*;

    use core_foundation::runloop::{kCFRunLoopCommonModes, CFRunLoop};
    use std::{
        ffi::CString,
        net::{Ipv4Addr, Ipv6Addr},
    };

    // Builds a target from the unspecified IPv4 and IPv6 addresses. Checks that a callback
    // can be set and that the target can be scheduled on, then unscheduled from, the run
    // loop returned by `CFRunLoop::get_current()`.
    #[test]
    fn test_network_reachability_from_addr() {
        let sockaddrs = vec![
            "0.0.0.0:0".parse::<SocketAddr>().unwrap(),
            "[::0]:0".parse::<SocketAddr>().unwrap(),
        ];

        for addr in sockaddrs {
            let mut reachability = SCNetworkReachability::from(addr);
            assert!(
                !reachability.0.is_null(),
                "Failed to construct a SCNetworkReachability struct with {}",
                addr
            );
            reachability.set_callback(|_| {}).unwrap();
            unsafe {
                // SAFETY: the mode is the `kCFRunLoopCommonModes` constant exported by
                // core_foundation, not an arbitrary pointer.
                reachability
                    .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
                reachability
                    .unschedule_from_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
            }
        }
    }

    // Same round trip as above, for all four combinations of unspecified IPv4/IPv6 local
    // and remote addresses.
    #[test]
    fn test_sockaddr_pair_reachability() {
        let pairs = vec![
            ("0.0.0.0:0", "[::0]:0"),
            ("[::0]:0", "0.0.0.0:0"),
            ("[::0]:0", "[::0]:0"),
            ("0.0.0.0:0", "0.0.0.0:0"),
        ]
        .into_iter()
        .map(|(a, b)| (a.parse().unwrap(), b.parse().unwrap()));

        for (local, remote) in pairs {
            let mut reachability = SCNetworkReachability::from_addr_pair(local, remote);
            assert!(
                !reachability.0.is_null(),
                "Failed to construct a SCNetworkReachability struct with address pair {} - {}",
                local,
                remote
            );
            reachability.set_callback(|_| {}).unwrap();
            unsafe {
                // SAFETY: the mode is the `kCFRunLoopCommonModes` constant exported by
                // core_foundation, not an arbitrary pointer.
                reachability
                    .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
                reachability
                    .unschedule_from_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
            }
        }
    }

    // Cross-checks reachability against a real TCP connection attempt to Google's DNS
    // addresses (IPv6 and IPv4) on port 443.
    // - If the connection fails, the pair (unspecified local address, remote) must not be
    //   REACHABLE.
    // - If it succeeds, the pair of the connection's actual local and peer addresses must
    //   be REACHABLE.
    // NOTE(review): the outcome depends on the network available to the test machine.
    #[test]
    fn test_sockaddr_local_to_dns_google_pair_reachability() {
        let sockaddrs = [
            "[2001:4860:4860::8844]:443".parse::<SocketAddr>().unwrap(),
            "8.8.4.4:443".parse().unwrap(),
        ];
        for remote_addr in sockaddrs {
            match std::net::TcpStream::connect(remote_addr) {
                Err(_) => {
                    let local_addr = if remote_addr.is_ipv4() {
                        SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
                    } else {
                        SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0)
                    };
                    let reachability =
                        SCNetworkReachability::from_addr_pair(local_addr, remote_addr);
                    let reachability_flags = reachability.reachability().unwrap();
                    assert!(!reachability_flags.contains(ReachabilityFlags::REACHABLE));
                }
                Ok(tcp) => {
                    let local = tcp.local_addr().unwrap();
                    let remote = tcp.peer_addr().unwrap();
                    let reachability = SCNetworkReachability::from_addr_pair(local, remote);
                    let reachability_flags = reachability.reachability().unwrap();
                    assert!(reachability_flags.contains(ReachabilityFlags::REACHABLE));
                }
            }
        }
    }

    // Each listed name must yield a target that accepts a callback and can be scheduled on
    // and unscheduled from the current run loop; an empty name must yield `None`.
    #[test]
    fn test_reachability_ref_from_host() {
        let valid_inputs = vec!["example.com", "host-in-local-network", "en0"];

        let get_cstring = |input: &str| CString::new(input).unwrap();

        for input in valid_inputs.into_iter().map(get_cstring) {
            match SCNetworkReachability::from_host(&input) {
                Some(mut reachability) => {
                    reachability.set_callback(|_| {}).unwrap();
                    unsafe {
                        // SAFETY: the mode is the `kCFRunLoopCommonModes` constant exported
                        // by core_foundation, not an arbitrary pointer.
                        reachability
                            .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                            .unwrap();
                        reachability
                            .unschedule_from_runloop(
                                &CFRunLoop::get_current(),
                                kCFRunLoopCommonModes,
                            )
                            .unwrap();
                    }
                }
                None => {
                    panic!(
                        "Failed to construct a SCNetworkReachability from {}",
                        input.to_string_lossy(),
                    );
                }
            }
        }

        assert!(
            // An empty host name must be rejected by `from_host`.
            SCNetworkReachability::from_host(&get_cstring("")).is_none(),
            "Constructed valid SCNetworkReachability from empty string"
        );
    }

    // Test-only (`cfg(test)`): lets `assert_infallibility_of_setting_a_callback` send a
    // target from its spawned thread back to the test thread over an mpsc channel.
    // NOTE(review): this asserts thread-safety of the wrapped CF object for tests only.
    unsafe impl Send for SCNetworkReachability {}

    // Sets a callback before and after scheduling the target on a spawned thread's run
    // loop. Then sends the target to the test thread, which sets a callback once more after
    // a one-second sleep (presumably so the spawned thread has entered
    // `CFRunLoop::run_current()` by then). Every `set_callback` call must succeed.
    #[test]
    fn assert_infallibility_of_setting_a_callback() {
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            let mut reachability =
                SCNetworkReachability::from("0.0.0.0:0".parse::<SocketAddr>().unwrap());
            reachability.set_callback(|_| {}).unwrap();
            unsafe {
                // SAFETY: the mode is the `kCFRunLoopCommonModes` constant exported by
                // core_foundation, not an arbitrary pointer.
                reachability
                    .schedule_with_runloop(&CFRunLoop::get_current(), kCFRunLoopCommonModes)
                    .unwrap();
            }
            reachability.set_callback(|_| {}).unwrap();
            let _ = tx.send(reachability);
            CFRunLoop::run_current();
        });
        let mut reachability = rx.recv().unwrap();
        std::thread::sleep(std::time::Duration::from_secs(1));
        reachability.set_callback(|_| {}).unwrap();
    }
}