1use std::{
2 ffi::c_void,
3 net::SocketAddr,
4 panic::AssertUnwindSafe,
5 sync::{Arc, Mutex},
6};
7
8use crate::{bridge, error::Result, ffi, SystemConfigurationError};
9
/// Raw reachability flag bits, as returned by `Reachability::flags` and passed to
/// reachability callbacks.
///
/// The named bits follow the SystemConfiguration `SCNetworkReachabilityFlags`
/// constants. Bits without a named constant are kept and visible through
/// [`bits`](Self::bits).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(
    /// The raw flag word.
    pub u32,
);

impl ReachabilityFlags {
    /// `kSCNetworkReachabilityFlagsTransientConnection` (bit 0).
    pub const TRANSIENT_CONNECTION: Self = Self(1 << 0);
    /// `kSCNetworkReachabilityFlagsReachable` (bit 1).
    pub const REACHABLE: Self = Self(1 << 1);
    /// `kSCNetworkReachabilityFlagsConnectionRequired` (bit 2).
    pub const CONNECTION_REQUIRED: Self = Self(1 << 2);
    /// `kSCNetworkReachabilityFlagsConnectionOnTraffic` (bit 3).
    pub const CONNECTION_ON_TRAFFIC: Self = Self(1 << 3);
    /// `kSCNetworkReachabilityFlagsInterventionRequired` (bit 4).
    pub const INTERVENTION_REQUIRED: Self = Self(1 << 4);
    /// `kSCNetworkReachabilityFlagsConnectionOnDemand` (bit 5).
    pub const CONNECTION_ON_DEMAND: Self = Self(1 << 5);
    /// `kSCNetworkReachabilityFlagsIsLocalAddress` (bit 16).
    pub const IS_LOCAL_ADDRESS: Self = Self(1 << 16);
    /// `kSCNetworkReachabilityFlagsIsDirect` (bit 17).
    pub const IS_DIRECT: Self = Self(1 << 17);
    /// `kSCNetworkReachabilityFlagsIsWWAN` (bit 18).
    pub const IS_WWAN: Self = Self(1 << 18);

    /// Short labels used by the `Display` impl, in output order.
    const LABELS: [(Self, &'static str); 9] = [
        (Self::TRANSIENT_CONNECTION, "transient"),
        (Self::REACHABLE, "reachable"),
        (Self::CONNECTION_REQUIRED, "needs-connection"),
        (Self::CONNECTION_ON_TRAFFIC, "on-traffic"),
        (Self::INTERVENTION_REQUIRED, "needs-intervention"),
        (Self::CONNECTION_ON_DEMAND, "on-demand"),
        (Self::IS_LOCAL_ADDRESS, "local-address"),
        (Self::IS_DIRECT, "direct"),
        (Self::IS_WWAN, "wwan"),
    ];

    /// Returns the raw flag word, including bits that have no named constant.
    pub const fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` (no bits set) is always contained.
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Tests [`TRANSIENT_CONNECTION`](Self::TRANSIENT_CONNECTION).
    pub const fn is_transient_connection(self) -> bool {
        self.contains(Self::TRANSIENT_CONNECTION)
    }

    /// Tests [`REACHABLE`](Self::REACHABLE).
    pub const fn is_reachable(self) -> bool {
        self.contains(Self::REACHABLE)
    }

    /// Tests [`CONNECTION_REQUIRED`](Self::CONNECTION_REQUIRED).
    pub const fn needs_connection(self) -> bool {
        self.contains(Self::CONNECTION_REQUIRED)
    }

    /// Tests [`CONNECTION_ON_TRAFFIC`](Self::CONNECTION_ON_TRAFFIC).
    pub const fn is_connection_on_traffic(self) -> bool {
        self.contains(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Tests [`INTERVENTION_REQUIRED`](Self::INTERVENTION_REQUIRED).
    pub const fn needs_intervention(self) -> bool {
        self.contains(Self::INTERVENTION_REQUIRED)
    }

    /// Tests [`CONNECTION_ON_DEMAND`](Self::CONNECTION_ON_DEMAND).
    pub const fn is_connection_on_demand(self) -> bool {
        self.contains(Self::CONNECTION_ON_DEMAND)
    }

    /// Tests [`IS_LOCAL_ADDRESS`](Self::IS_LOCAL_ADDRESS).
    pub const fn is_local_address(self) -> bool {
        self.contains(Self::IS_LOCAL_ADDRESS)
    }

    /// Tests [`IS_DIRECT`](Self::IS_DIRECT).
    pub const fn is_direct(self) -> bool {
        self.contains(Self::IS_DIRECT)
    }

    /// Tests [`IS_WWAN`](Self::IS_WWAN).
    pub const fn is_wwan(self) -> bool {
        self.contains(Self::IS_WWAN)
    }
}

/// Formats as `label|label (0xHEX)` for the recognised bits, or just `0xHEX` when
/// no named bit is set. The hex value always shows the full raw word.
impl std::fmt::Display for ReachabilityFlags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut wrote_label = false;
        for &(flag, label) in Self::LABELS.iter() {
            if !self.contains(flag) {
                continue;
            }
            if wrote_label {
                f.write_str("|")?;
            }
            f.write_str(label)?;
            wrote_label = true;
        }
        if wrote_label {
            write!(f, " (0x{:x})", self.bits())
        } else {
            write!(f, "0x{:x}", self.bits())
        }
    }
}
106
107struct LocalCallbackState {
108 callback: Box<dyn FnMut(ReachabilityFlags)>,
109}
110
111struct SendCallbackState {
112 callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
113}
114
115enum RegisteredCallback {
116 Local {
117 _state: Box<LocalCallbackState>,
118 },
119 Send {
120 _state: Arc<Mutex<SendCallbackState>>,
121 },
122}
123
124unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
125 if info.is_null() {
126 return;
127 }
128
129 let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
130 let _ = std::panic::catch_unwind(AssertUnwindSafe(|| {
132 (state.callback)(ReachabilityFlags(flags));
133 }));
134}
135
136unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
137 if info.is_null() {
138 return;
139 }
140
141 let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
142 if let Ok(mut state) = mutex.lock() {
143 let _ = std::panic::catch_unwind(AssertUnwindSafe(|| {
145 (state.callback)(ReachabilityFlags(flags));
146 }));
147 }
148}
149
/// A reachability target created by `with_name`, `with_address` or
/// `with_address_pair`.
///
/// Tracks the callback, run-loop and dispatch-queue registrations it has made so
/// that `Drop` can undo them.
pub struct Reachability {
    // Owned native handle. Fields drop in declaration order, so this is dropped
    // before `callback`. By then `Drop::drop` has already unregistered the callback.
    // NOTE(review): presumably `bridge::OwnedHandle` releases the native object on
    // drop; confirm in `bridge`.
    raw: bridge::OwnedHandle,
    // Keeps alive the state the registered `info` pointer refers to.
    // `None` when no callback has been registered (or it was cleared).
    callback: Option<RegisteredCallback>,
    // Set by `schedule_with_run_loop_current`, cleared by
    // `unschedule_from_run_loop_current`.
    scheduled_with_current_run_loop: bool,
    // Set by `set_dispatch_queue_global`, cleared by `clear_dispatch_queue`.
    dispatch_queue_active: bool,
}

/// Alternative name for [`Reachability`].
pub type NetworkReachability = Reachability;
160
161impl Reachability {
162 pub fn type_id() -> u64 {
164 unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
165 }
166
167 pub fn with_name(name: &str) -> Result<Self> {
169 let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
170 let raw =
171 unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
172 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
173 Ok(Self {
174 raw,
175 callback: None,
176 scheduled_with_current_run_loop: false,
177 dispatch_queue_active: false,
178 })
179 }
180
181 pub fn with_address(address: SocketAddr) -> Result<Self> {
183 let storage = socket_addr_to_bytes(address);
184 let raw = unsafe {
185 ffi::network_reachability::sc_reachability_create_with_address(
186 storage.as_ptr(),
187 isize::try_from(storage.len()).expect("socket address length exceeded isize"),
188 )
189 };
190 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
191 Ok(Self {
192 raw,
193 callback: None,
194 scheduled_with_current_run_loop: false,
195 dispatch_queue_active: false,
196 })
197 }
198
199 pub fn with_address_pair(
201 local_address: Option<SocketAddr>,
202 remote_address: Option<SocketAddr>,
203 ) -> Result<Self> {
204 let local = local_address.map(socket_addr_to_bytes);
205 let remote = remote_address.map(socket_addr_to_bytes);
206 let raw = unsafe {
207 ffi::network_reachability::sc_reachability_create_with_address_pair(
208 local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
209 local.as_ref().map_or(0, |value| {
210 isize::try_from(value.len()).expect("socket address length exceeded isize")
211 }),
212 remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
213 remote.as_ref().map_or(0, |value| {
214 isize::try_from(value.len()).expect("socket address length exceeded isize")
215 }),
216 )
217 };
218 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
219 Ok(Self {
220 raw,
221 callback: None,
222 scheduled_with_current_run_loop: false,
223 dispatch_queue_active: false,
224 })
225 }
226
227 pub fn flags(&self) -> Result<ReachabilityFlags> {
229 let mut flags = 0_u32;
230 let ok = unsafe {
231 ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
232 };
233 bridge::bool_result("sc_reachability_get_flags", ok)?;
234 Ok(ReachabilityFlags(flags))
235 }
236
237 pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
239 where
240 F: FnMut(ReachabilityFlags) + 'static,
241 {
242 if self.dispatch_queue_active {
243 return Err(SystemConfigurationError::null(
244 "sc_reachability_set_callback",
245 "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
246 ));
247 }
248
249 let mut callback = Box::new(LocalCallbackState {
250 callback: Box::new(callback),
251 });
252 self.set_registered_callback(
253 Some(reachability_callback_local),
254 std::ptr::from_mut(&mut *callback).cast::<c_void>(),
255 Some(RegisteredCallback::Local { _state: callback }),
256 )
257 }
258
259 pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
261 where
262 F: FnMut(ReachabilityFlags) + Send + 'static,
263 {
264 let callback = Arc::new(Mutex::new(SendCallbackState {
265 callback: Box::new(callback),
266 }));
267 self.set_registered_callback(
268 Some(reachability_callback_send),
269 Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
270 Some(RegisteredCallback::Send { _state: callback }),
271 )
272 }
273
274 pub fn clear_callback(&mut self) -> Result<()> {
276 if self.dispatch_queue_active {
277 self.clear_dispatch_queue()?;
278 }
279 self.set_registered_callback(None, std::ptr::null_mut(), None)
280 }
281
282 pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
284 let ok = unsafe {
285 ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
286 self.raw.as_ptr(),
287 )
288 };
289 bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
290 self.scheduled_with_current_run_loop = true;
291 Ok(())
292 }
293
294 pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
296 let ok = unsafe {
297 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
298 self.raw.as_ptr(),
299 )
300 };
301 bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
302 self.scheduled_with_current_run_loop = false;
303 Ok(())
304 }
305
306 pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
308 if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
309 return Err(SystemConfigurationError::null(
310 "sc_reachability_set_dispatch_queue_global",
311 "dispatch queues require callbacks registered via Reachability::set_callback_send",
312 ));
313 }
314
315 let ok = unsafe {
316 ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
317 };
318 bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
319 self.dispatch_queue_active = true;
320 Ok(())
321 }
322
323 pub fn clear_dispatch_queue(&mut self) -> Result<()> {
325 let ok = unsafe {
326 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
327 };
328 bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
329 self.dispatch_queue_active = false;
330 Ok(())
331 }
332
333 fn set_registered_callback(
334 &mut self,
335 callback: ffi::network_reachability::ReachabilityCallback,
336 info: *mut c_void,
337 registered: Option<RegisteredCallback>,
338 ) -> Result<()> {
339 let ok = unsafe {
340 ffi::network_reachability::sc_reachability_set_callback(
341 self.raw.as_ptr(),
342 callback,
343 info,
344 )
345 };
346 bridge::bool_result("sc_reachability_set_callback", ok)?;
347 self.callback = registered;
348 Ok(())
349 }
350}
351
352impl Drop for Reachability {
353 fn drop(&mut self) {
354 if self.dispatch_queue_active {
355 let _ = unsafe {
356 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
357 };
358 }
359 if self.scheduled_with_current_run_loop {
360 let _ = unsafe {
361 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
362 self.raw.as_ptr(),
363 )
364 };
365 }
366 if self.callback.is_some() {
367 let _ = unsafe {
368 ffi::network_reachability::sc_reachability_set_callback(
369 self.raw.as_ptr(),
370 None,
371 std::ptr::null_mut(),
372 )
373 };
374 }
375 }
376}
377
378fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
379 match address {
380 SocketAddr::V4(address) => {
381 let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
382 storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
383 .expect("sockaddr_in length exceeds u8");
384 storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
385 storage.sin_port = address.port().to_be();
386 storage.sin_addr = libc::in_addr {
387 s_addr: u32::from_ne_bytes(address.ip().octets()),
388 };
389 unsafe {
390 std::slice::from_raw_parts(
391 std::ptr::from_ref(&storage).cast::<u8>(),
392 std::mem::size_of::<libc::sockaddr_in>(),
393 )
394 .to_vec()
395 }
396 }
397 SocketAddr::V6(address) => {
398 let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
399 storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
400 .expect("sockaddr_in6 length exceeds u8");
401 storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
402 storage.sin6_port = address.port().to_be();
403 storage.sin6_flowinfo = address.flowinfo();
404 storage.sin6_scope_id = address.scope_id();
405 storage.sin6_addr = libc::in6_addr {
406 s6_addr: address.ip().octets(),
407 };
408 unsafe {
409 std::slice::from_raw_parts(
410 std::ptr::from_ref(&storage).cast::<u8>(),
411 std::mem::size_of::<libc::sockaddr_in6>(),
412 )
413 .to_vec()
414 }
415 }
416 }
417}