Skip to main content

systemconfiguration/
network_reachability.rs

1use std::{
2    ffi::c_void,
3    net::SocketAddr,
4    sync::{Arc, Mutex},
5};
6
7use crate::{bridge, error::Result, ffi, SystemConfigurationError};
8
/// Wraps `SCNetworkReachabilityFlags`.
///
/// Holds the raw flag bitfield delivered by System Configuration. The associated
/// constants name the individual bits and the predicate methods test them; bits
/// without a named constant are still preserved in [`ReachabilityFlags::bits`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(
    /// The raw `SCNetworkReachabilityFlags` bitfield.
    pub u32,
);

impl ReachabilityFlags {
    /// `kSCNetworkReachabilityFlagsTransientConnection` (bit 0).
    pub const TRANSIENT_CONNECTION: u32 = 1 << 0;
    /// `kSCNetworkReachabilityFlagsReachable` (bit 1).
    pub const REACHABLE: u32 = 1 << 1;
    /// `kSCNetworkReachabilityFlagsConnectionRequired` (bit 2).
    pub const CONNECTION_REQUIRED: u32 = 1 << 2;
    /// `kSCNetworkReachabilityFlagsConnectionOnTraffic` (bit 3).
    pub const CONNECTION_ON_TRAFFIC: u32 = 1 << 3;
    /// `kSCNetworkReachabilityFlagsInterventionRequired` (bit 4).
    pub const INTERVENTION_REQUIRED: u32 = 1 << 4;
    /// `kSCNetworkReachabilityFlagsConnectionOnDemand` (bit 5).
    pub const CONNECTION_ON_DEMAND: u32 = 1 << 5;
    /// `kSCNetworkReachabilityFlagsIsLocalAddress` (bit 16).
    pub const IS_LOCAL_ADDRESS: u32 = 1 << 16;
    /// `kSCNetworkReachabilityFlagsIsDirect` (bit 17).
    pub const IS_DIRECT: u32 = 1 << 17;
    /// `kSCNetworkReachabilityFlagsIsWWAN` (bit 18).
    pub const IS_WWAN: u32 = 1 << 18;

    /// Named bits paired with the label `Display` prints for them, in print order.
    const LABELS: [(u32, &'static str); 9] = [
        (Self::TRANSIENT_CONNECTION, "transient"),
        (Self::REACHABLE, "reachable"),
        (Self::CONNECTION_REQUIRED, "needs-connection"),
        (Self::CONNECTION_ON_TRAFFIC, "on-traffic"),
        (Self::INTERVENTION_REQUIRED, "needs-intervention"),
        (Self::CONNECTION_ON_DEMAND, "on-demand"),
        (Self::IS_LOCAL_ADDRESS, "local-address"),
        (Self::IS_DIRECT, "direct"),
        (Self::IS_WWAN, "wwan"),
    ];

    /// Returns `true` if any bit of `mask` is set.
    fn has(self, mask: u32) -> bool {
        self.0 & mask != 0
    }

    /// Returns the raw bitfield, including bits that have no named constant.
    pub fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` if [`ReachabilityFlags::TRANSIENT_CONNECTION`] is set.
    pub fn is_transient_connection(self) -> bool {
        self.has(Self::TRANSIENT_CONNECTION)
    }

    /// Returns `true` if [`ReachabilityFlags::REACHABLE`] is set.
    pub fn is_reachable(self) -> bool {
        self.has(Self::REACHABLE)
    }

    /// Returns `true` if [`ReachabilityFlags::CONNECTION_REQUIRED`] is set.
    pub fn needs_connection(self) -> bool {
        self.has(Self::CONNECTION_REQUIRED)
    }

    /// Returns `true` if [`ReachabilityFlags::CONNECTION_ON_TRAFFIC`] is set.
    pub fn is_connection_on_traffic(self) -> bool {
        self.has(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Returns `true` if [`ReachabilityFlags::INTERVENTION_REQUIRED`] is set.
    pub fn needs_intervention(self) -> bool {
        self.has(Self::INTERVENTION_REQUIRED)
    }

    /// Returns `true` if [`ReachabilityFlags::CONNECTION_ON_DEMAND`] is set.
    pub fn is_connection_on_demand(self) -> bool {
        self.has(Self::CONNECTION_ON_DEMAND)
    }

    /// Returns `true` if [`ReachabilityFlags::IS_LOCAL_ADDRESS`] is set.
    pub fn is_local_address(self) -> bool {
        self.has(Self::IS_LOCAL_ADDRESS)
    }

    /// Returns `true` if [`ReachabilityFlags::IS_DIRECT`] is set.
    pub fn is_direct(self) -> bool {
        self.has(Self::IS_DIRECT)
    }

    /// Returns `true` if [`ReachabilityFlags::IS_WWAN`] is set.
    pub fn is_wwan(self) -> bool {
        self.has(Self::IS_WWAN)
    }
}

impl std::fmt::Display for ReachabilityFlags {
    /// Formats as `label|label (0x…)` listing every named bit that is set, or as
    /// just `0x…` when no named bit is set. The hex value is always the full
    /// raw bitfield.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let labels: Vec<&str> = Self::LABELS
            .iter()
            .filter(|&&(mask, _)| self.has(mask))
            .map(|&(_, label)| label)
            .collect();
        if labels.is_empty() {
            write!(f, "0x{:x}", self.bits())
        } else {
            write!(f, "{} (0x{:x})", labels.join("|"), self.bits())
        }
    }
}
105
106struct LocalCallbackState {
107    callback: Box<dyn FnMut(ReachabilityFlags)>,
108}
109
110struct SendCallbackState {
111    callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
112}
113
114enum RegisteredCallback {
115    Local {
116        _state: Box<LocalCallbackState>,
117    },
118    Send {
119        _state: Arc<Mutex<SendCallbackState>>,
120    },
121}
122
123unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
124    if info.is_null() {
125        return;
126    }
127
128    let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
129    (state.callback)(ReachabilityFlags(flags));
130}
131
132unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
133    if info.is_null() {
134        return;
135    }
136
137    let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
138    if let Ok(mut state) = mutex.lock() {
139        (state.callback)(ReachabilityFlags(flags));
140    }
141}
142
/// Wraps `SCNetworkReachabilityRef`.
///
/// Owns the reachability handle plus the Rust-side bookkeeping that `Drop`
/// uses to undo registrations.
///
/// Fields are dropped in declaration order once `Drop::drop` returns, so `raw`
/// is dropped before `callback`. `Drop::drop` unregisters the callback before
/// either field is dropped.
pub struct Reachability {
    /// Owned handle to the underlying reachability object.
    raw: bridge::OwnedHandle,
    /// Keeps alive the state whose address was registered as the callback's
    /// `info` pointer; `None` when no callback is registered.
    callback: Option<RegisteredCallback>,
    /// Set after a successful `schedule_with_run_loop_current`, cleared by
    /// `unschedule_from_run_loop_current`.
    scheduled_with_current_run_loop: bool,
    /// Set after a successful `set_dispatch_queue_global`, cleared by
    /// `clear_dispatch_queue`.
    dispatch_queue_active: bool,
}

/// Alias for the `SCNetworkReachabilityRef` wrapper.
pub type NetworkReachability = Reachability;
153
154impl Reachability {
155    /// Wraps `SCReachabilityGetTypeID`.
156    pub fn type_id() -> u64 {
157        unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
158    }
159
160    /// Wraps `SCReachabilityCreateWithName`.
161    pub fn with_name(name: &str) -> Result<Self> {
162        let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
163        let raw =
164            unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
165        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
166        Ok(Self {
167            raw,
168            callback: None,
169            scheduled_with_current_run_loop: false,
170            dispatch_queue_active: false,
171        })
172    }
173
174    /// Wraps `SCReachabilityCreateWithAddress`.
175    pub fn with_address(address: SocketAddr) -> Result<Self> {
176        let storage = socket_addr_to_bytes(address);
177        let raw = unsafe {
178            ffi::network_reachability::sc_reachability_create_with_address(
179                storage.as_ptr(),
180                isize::try_from(storage.len()).expect("socket address length exceeded isize"),
181            )
182        };
183        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
184        Ok(Self {
185            raw,
186            callback: None,
187            scheduled_with_current_run_loop: false,
188            dispatch_queue_active: false,
189        })
190    }
191
192    /// Wraps `SCReachabilityCreateWithAddressPair`.
193    pub fn with_address_pair(
194        local_address: Option<SocketAddr>,
195        remote_address: Option<SocketAddr>,
196    ) -> Result<Self> {
197        let local = local_address.map(socket_addr_to_bytes);
198        let remote = remote_address.map(socket_addr_to_bytes);
199        let raw = unsafe {
200            ffi::network_reachability::sc_reachability_create_with_address_pair(
201                local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
202                local.as_ref().map_or(0, |value| {
203                    isize::try_from(value.len()).expect("socket address length exceeded isize")
204                }),
205                remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
206                remote.as_ref().map_or(0, |value| {
207                    isize::try_from(value.len()).expect("socket address length exceeded isize")
208                }),
209            )
210        };
211        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
212        Ok(Self {
213            raw,
214            callback: None,
215            scheduled_with_current_run_loop: false,
216            dispatch_queue_active: false,
217        })
218    }
219
220    /// Wraps `SCReachabilityGetFlags`.
221    pub fn flags(&self) -> Result<ReachabilityFlags> {
222        let mut flags = 0_u32;
223        let ok = unsafe {
224            ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
225        };
226        bridge::bool_result("sc_reachability_get_flags", ok)?;
227        Ok(ReachabilityFlags(flags))
228    }
229
230    /// Wraps a helper on `SCNetworkReachabilityRef`.
231    pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
232    where
233        F: FnMut(ReachabilityFlags) + 'static,
234    {
235        if self.dispatch_queue_active {
236            return Err(SystemConfigurationError::null(
237                "sc_reachability_set_callback",
238                "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
239            ));
240        }
241
242        let mut callback = Box::new(LocalCallbackState {
243            callback: Box::new(callback),
244        });
245        self.set_registered_callback(
246            Some(reachability_callback_local),
247            std::ptr::from_mut(&mut *callback).cast::<c_void>(),
248            Some(RegisteredCallback::Local { _state: callback }),
249        )
250    }
251
252    /// Wraps a helper on `SCNetworkReachabilityRef`.
253    pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
254    where
255        F: FnMut(ReachabilityFlags) + Send + 'static,
256    {
257        let callback = Arc::new(Mutex::new(SendCallbackState {
258            callback: Box::new(callback),
259        }));
260        self.set_registered_callback(
261            Some(reachability_callback_send),
262            Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
263            Some(RegisteredCallback::Send { _state: callback }),
264        )
265    }
266
267    /// Wraps a helper on `SCNetworkReachabilityRef`.
268    pub fn clear_callback(&mut self) -> Result<()> {
269        if self.dispatch_queue_active {
270            self.clear_dispatch_queue()?;
271        }
272        self.set_registered_callback(None, std::ptr::null_mut(), None)
273    }
274
275    /// Wraps `SCReachabilityScheduleWithRunLoopCurrent`.
276    pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
277        let ok = unsafe {
278            ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
279                self.raw.as_ptr(),
280            )
281        };
282        bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
283        self.scheduled_with_current_run_loop = true;
284        Ok(())
285    }
286
287    /// Wraps `SCReachabilityUnscheduleFromRunLoopCurrent`.
288    pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
289        let ok = unsafe {
290            ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
291                self.raw.as_ptr(),
292            )
293        };
294        bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
295        self.scheduled_with_current_run_loop = false;
296        Ok(())
297    }
298
299    /// Wraps `SCReachabilitySetDispatchQueueGlobal`.
300    pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
301        if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
302            return Err(SystemConfigurationError::null(
303                "sc_reachability_set_dispatch_queue_global",
304                "dispatch queues require callbacks registered via Reachability::set_callback_send",
305            ));
306        }
307
308        let ok = unsafe {
309            ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
310        };
311        bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
312        self.dispatch_queue_active = true;
313        Ok(())
314    }
315
316    /// Wraps `SCReachabilityClearDispatchQueue`.
317    pub fn clear_dispatch_queue(&mut self) -> Result<()> {
318        let ok = unsafe {
319            ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
320        };
321        bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
322        self.dispatch_queue_active = false;
323        Ok(())
324    }
325
326    fn set_registered_callback(
327        &mut self,
328        callback: ffi::network_reachability::ReachabilityCallback,
329        info: *mut c_void,
330        registered: Option<RegisteredCallback>,
331    ) -> Result<()> {
332        let ok = unsafe {
333            ffi::network_reachability::sc_reachability_set_callback(
334                self.raw.as_ptr(),
335                callback,
336                info,
337            )
338        };
339        bridge::bool_result("sc_reachability_set_callback", ok)?;
340        self.callback = registered;
341        Ok(())
342    }
343}
344
345impl Drop for Reachability {
346    fn drop(&mut self) {
347        if self.dispatch_queue_active {
348            let _ = unsafe {
349                ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
350            };
351        }
352        if self.scheduled_with_current_run_loop {
353            let _ = unsafe {
354                ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
355                    self.raw.as_ptr(),
356                )
357            };
358        }
359        if self.callback.is_some() {
360            let _ = unsafe {
361                ffi::network_reachability::sc_reachability_set_callback(
362                    self.raw.as_ptr(),
363                    None,
364                    std::ptr::null_mut(),
365                )
366            };
367        }
368    }
369}
370
371fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
372    match address {
373        SocketAddr::V4(address) => {
374            let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
375            storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
376                .expect("sockaddr_in length exceeds u8");
377            storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
378            storage.sin_port = address.port().to_be();
379            storage.sin_addr = libc::in_addr {
380                s_addr: u32::from_ne_bytes(address.ip().octets()),
381            };
382            unsafe {
383                std::slice::from_raw_parts(
384                    std::ptr::from_ref(&storage).cast::<u8>(),
385                    std::mem::size_of::<libc::sockaddr_in>(),
386                )
387                .to_vec()
388            }
389        }
390        SocketAddr::V6(address) => {
391            let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
392            storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
393                .expect("sockaddr_in6 length exceeds u8");
394            storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
395            storage.sin6_port = address.port().to_be();
396            storage.sin6_flowinfo = address.flowinfo();
397            storage.sin6_scope_id = address.scope_id();
398            storage.sin6_addr = libc::in6_addr {
399                s6_addr: address.ip().octets(),
400            };
401            unsafe {
402                std::slice::from_raw_parts(
403                    std::ptr::from_ref(&storage).cast::<u8>(),
404                    std::mem::size_of::<libc::sockaddr_in6>(),
405                )
406                .to_vec()
407            }
408        }
409    }
410}