1use std::{
2 ffi::c_void,
3 net::SocketAddr,
4 sync::{Arc, Mutex},
5};
6
7use crate::{bridge, error::Result, ffi, SystemConfigurationError};
8
/// Raw reachability flag word reported by SystemConfiguration for a target.
///
/// The bit positions mirror Apple's `SCNetworkReachabilityFlags`
/// (`kSCNetworkReachabilityFlags*`). Bits this type has no accessor for are
/// kept as-is and remain visible through [`ReachabilityFlags::bits`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(pub u32);

impl ReachabilityFlags {
    /// `kSCNetworkReachabilityFlagsTransientConnection`.
    pub const TRANSIENT_CONNECTION: Self = Self(1 << 0);
    /// `kSCNetworkReachabilityFlagsReachable`.
    pub const REACHABLE: Self = Self(1 << 1);
    /// `kSCNetworkReachabilityFlagsConnectionRequired`.
    pub const CONNECTION_REQUIRED: Self = Self(1 << 2);
    /// `kSCNetworkReachabilityFlagsConnectionOnTraffic`.
    pub const CONNECTION_ON_TRAFFIC: Self = Self(1 << 3);
    /// `kSCNetworkReachabilityFlagsInterventionRequired`.
    pub const INTERVENTION_REQUIRED: Self = Self(1 << 4);
    /// `kSCNetworkReachabilityFlagsConnectionOnDemand`.
    pub const CONNECTION_ON_DEMAND: Self = Self(1 << 5);
    /// `kSCNetworkReachabilityFlagsIsLocalAddress`.
    pub const IS_LOCAL_ADDRESS: Self = Self(1 << 16);
    /// `kSCNetworkReachabilityFlagsIsDirect`.
    pub const IS_DIRECT: Self = Self(1 << 17);
    /// `kSCNetworkReachabilityFlagsIsWWAN`.
    pub const IS_WWAN: Self = Self(1 << 18);

    /// Returns the raw flag word, including bits without a named accessor.
    pub fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` (no bits set) is trivially contained.
    pub fn contains(self, other: Self) -> bool {
        self.0 & other.0 == other.0
    }

    /// The target is reached through a transient connection.
    pub fn is_transient_connection(self) -> bool {
        self.contains(Self::TRANSIENT_CONNECTION)
    }

    /// The target is reachable with the current network configuration.
    pub fn is_reachable(self) -> bool {
        self.contains(Self::REACHABLE)
    }

    /// A connection must be established before the target is reachable.
    pub fn needs_connection(self) -> bool {
        self.contains(Self::CONNECTION_REQUIRED)
    }

    /// A required connection is established automatically on traffic.
    pub fn is_connection_on_traffic(self) -> bool {
        self.contains(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Establishing the connection requires user intervention.
    pub fn needs_intervention(self) -> bool {
        self.contains(Self::INTERVENTION_REQUIRED)
    }

    /// A required connection is established on demand.
    pub fn is_connection_on_demand(self) -> bool {
        self.contains(Self::CONNECTION_ON_DEMAND)
    }

    /// The target address belongs to a local network interface.
    pub fn is_local_address(self) -> bool {
        self.contains(Self::IS_LOCAL_ADDRESS)
    }

    /// Traffic to the target does not pass through a gateway.
    pub fn is_direct(self) -> bool {
        self.contains(Self::IS_DIRECT)
    }

    /// The target is reached over a cellular (WWAN) interface.
    pub fn is_wwan(self) -> bool {
        self.contains(Self::IS_WWAN)
    }
}
53
54impl std::fmt::Display for ReachabilityFlags {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 let mut labels = Vec::new();
57 if self.is_transient_connection() {
58 labels.push("transient");
59 }
60 if self.is_reachable() {
61 labels.push("reachable");
62 }
63 if self.needs_connection() {
64 labels.push("needs-connection");
65 }
66 if self.is_connection_on_traffic() {
67 labels.push("on-traffic");
68 }
69 if self.needs_intervention() {
70 labels.push("needs-intervention");
71 }
72 if self.is_connection_on_demand() {
73 labels.push("on-demand");
74 }
75 if self.is_local_address() {
76 labels.push("local-address");
77 }
78 if self.is_direct() {
79 labels.push("direct");
80 }
81 if self.is_wwan() {
82 labels.push("wwan");
83 }
84 if labels.is_empty() {
85 write!(f, "0x{:x}", self.bits())
86 } else {
87 write!(f, "{} (0x{:x})", labels.join("|"), self.bits())
88 }
89 }
90}
91
92struct LocalCallbackState {
93 callback: Box<dyn FnMut(ReachabilityFlags)>,
94}
95
96struct SendCallbackState {
97 callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
98}
99
100enum RegisteredCallback {
101 Local {
102 _state: Box<LocalCallbackState>,
103 },
104 Send {
105 _state: Arc<Mutex<SendCallbackState>>,
106 },
107}
108
109unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
110 if info.is_null() {
111 return;
112 }
113
114 let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
115 (state.callback)(ReachabilityFlags(flags));
116}
117
118unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
119 if info.is_null() {
120 return;
121 }
122
123 let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
124 if let Ok(mut state) = mutex.lock() {
125 (state.callback)(ReachabilityFlags(flags));
126 }
127}
128
/// Owning wrapper around a SystemConfiguration network-reachability target.
///
/// Besides the handle, it records which registrations this wrapper made
/// (callback, run-loop scheduling, dispatch queue) so its `Drop` impl can undo
/// them.
///
/// Fields drop in declaration order: `raw` is released before the callback
/// state in `callback`.
pub struct Reachability {
    // Owned handle to the reachability object.
    raw: bridge::OwnedHandle,
    // Keeps the memory behind the registered C callback's `info` pointer
    // alive; `None` when no callback was registered through this wrapper.
    callback: Option<RegisteredCallback>,
    // Set by `schedule_with_run_loop_current`, cleared by
    // `unschedule_from_run_loop_current`.
    scheduled_with_current_run_loop: bool,
    // Set by `set_dispatch_queue_global`, cleared by `clear_dispatch_queue`.
    dispatch_queue_active: bool,
}

/// Alternative name for [`Reachability`].
pub type NetworkReachability = Reachability;
137
138impl Reachability {
139 pub fn type_id() -> u64 {
140 unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
141 }
142
143 pub fn with_name(name: &str) -> Result<Self> {
144 let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
145 let raw =
146 unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
147 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
148 Ok(Self {
149 raw,
150 callback: None,
151 scheduled_with_current_run_loop: false,
152 dispatch_queue_active: false,
153 })
154 }
155
156 pub fn with_address(address: SocketAddr) -> Result<Self> {
157 let storage = socket_addr_to_bytes(address);
158 let raw = unsafe {
159 ffi::network_reachability::sc_reachability_create_with_address(
160 storage.as_ptr(),
161 isize::try_from(storage.len()).expect("socket address length exceeded isize"),
162 )
163 };
164 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
165 Ok(Self {
166 raw,
167 callback: None,
168 scheduled_with_current_run_loop: false,
169 dispatch_queue_active: false,
170 })
171 }
172
173 pub fn with_address_pair(
174 local_address: Option<SocketAddr>,
175 remote_address: Option<SocketAddr>,
176 ) -> Result<Self> {
177 let local = local_address.map(socket_addr_to_bytes);
178 let remote = remote_address.map(socket_addr_to_bytes);
179 let raw = unsafe {
180 ffi::network_reachability::sc_reachability_create_with_address_pair(
181 local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
182 local.as_ref().map_or(0, |value| {
183 isize::try_from(value.len()).expect("socket address length exceeded isize")
184 }),
185 remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
186 remote.as_ref().map_or(0, |value| {
187 isize::try_from(value.len()).expect("socket address length exceeded isize")
188 }),
189 )
190 };
191 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
192 Ok(Self {
193 raw,
194 callback: None,
195 scheduled_with_current_run_loop: false,
196 dispatch_queue_active: false,
197 })
198 }
199
200 pub fn flags(&self) -> Result<ReachabilityFlags> {
201 let mut flags = 0_u32;
202 let ok = unsafe {
203 ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
204 };
205 bridge::bool_result("sc_reachability_get_flags", ok)?;
206 Ok(ReachabilityFlags(flags))
207 }
208
209 pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
210 where
211 F: FnMut(ReachabilityFlags) + 'static,
212 {
213 if self.dispatch_queue_active {
214 return Err(SystemConfigurationError::null(
215 "sc_reachability_set_callback",
216 "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
217 ));
218 }
219
220 let mut callback = Box::new(LocalCallbackState {
221 callback: Box::new(callback),
222 });
223 self.set_registered_callback(
224 Some(reachability_callback_local),
225 std::ptr::from_mut(&mut *callback).cast::<c_void>(),
226 Some(RegisteredCallback::Local { _state: callback }),
227 )
228 }
229
230 pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
231 where
232 F: FnMut(ReachabilityFlags) + Send + 'static,
233 {
234 let callback = Arc::new(Mutex::new(SendCallbackState {
235 callback: Box::new(callback),
236 }));
237 self.set_registered_callback(
238 Some(reachability_callback_send),
239 Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
240 Some(RegisteredCallback::Send { _state: callback }),
241 )
242 }
243
244 pub fn clear_callback(&mut self) -> Result<()> {
245 if self.dispatch_queue_active {
246 self.clear_dispatch_queue()?;
247 }
248 self.set_registered_callback(None, std::ptr::null_mut(), None)
249 }
250
251 pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
252 let ok = unsafe {
253 ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
254 self.raw.as_ptr(),
255 )
256 };
257 bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
258 self.scheduled_with_current_run_loop = true;
259 Ok(())
260 }
261
262 pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
263 let ok = unsafe {
264 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
265 self.raw.as_ptr(),
266 )
267 };
268 bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
269 self.scheduled_with_current_run_loop = false;
270 Ok(())
271 }
272
273 pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
274 if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
275 return Err(SystemConfigurationError::null(
276 "sc_reachability_set_dispatch_queue_global",
277 "dispatch queues require callbacks registered via Reachability::set_callback_send",
278 ));
279 }
280
281 let ok = unsafe {
282 ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
283 };
284 bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
285 self.dispatch_queue_active = true;
286 Ok(())
287 }
288
289 pub fn clear_dispatch_queue(&mut self) -> Result<()> {
290 let ok = unsafe {
291 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
292 };
293 bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
294 self.dispatch_queue_active = false;
295 Ok(())
296 }
297
298 fn set_registered_callback(
299 &mut self,
300 callback: ffi::network_reachability::ReachabilityCallback,
301 info: *mut c_void,
302 registered: Option<RegisteredCallback>,
303 ) -> Result<()> {
304 let ok = unsafe {
305 ffi::network_reachability::sc_reachability_set_callback(
306 self.raw.as_ptr(),
307 callback,
308 info,
309 )
310 };
311 bridge::bool_result("sc_reachability_set_callback", ok)?;
312 self.callback = registered;
313 Ok(())
314 }
315}
316
317impl Drop for Reachability {
318 fn drop(&mut self) {
319 if self.dispatch_queue_active {
320 let _ = unsafe {
321 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
322 };
323 }
324 if self.scheduled_with_current_run_loop {
325 let _ = unsafe {
326 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
327 self.raw.as_ptr(),
328 )
329 };
330 }
331 if self.callback.is_some() {
332 let _ = unsafe {
333 ffi::network_reachability::sc_reachability_set_callback(
334 self.raw.as_ptr(),
335 None,
336 std::ptr::null_mut(),
337 )
338 };
339 }
340 }
341}
342
343fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
344 match address {
345 SocketAddr::V4(address) => {
346 let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
347 storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
348 .expect("sockaddr_in length exceeds u8");
349 storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
350 storage.sin_port = address.port().to_be();
351 storage.sin_addr = libc::in_addr {
352 s_addr: u32::from_ne_bytes(address.ip().octets()),
353 };
354 unsafe {
355 std::slice::from_raw_parts(
356 std::ptr::from_ref(&storage).cast::<u8>(),
357 std::mem::size_of::<libc::sockaddr_in>(),
358 )
359 .to_vec()
360 }
361 }
362 SocketAddr::V6(address) => {
363 let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
364 storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
365 .expect("sockaddr_in6 length exceeds u8");
366 storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
367 storage.sin6_port = address.port().to_be();
368 storage.sin6_flowinfo = address.flowinfo();
369 storage.sin6_scope_id = address.scope_id();
370 storage.sin6_addr = libc::in6_addr {
371 s6_addr: address.ip().octets(),
372 };
373 unsafe {
374 std::slice::from_raw_parts(
375 std::ptr::from_ref(&storage).cast::<u8>(),
376 std::mem::size_of::<libc::sockaddr_in6>(),
377 )
378 .to_vec()
379 }
380 }
381 }
382}