Skip to main content

veilid_tools/
tools.rs

1use super::*;
2
3use std::io;
4use std::path::Path;
5
6//////////////////////////////////////////////////////////////////////////////////////////////////////////////
7
/// Assert that an expression evaluates to `Err(..)`.
///
/// The expression is evaluated exactly once. When it yields `Ok(v)` the macro panics,
/// printing `v` with `{:?}` (so the `Ok` payload must implement `Debug`); an `Err`
/// value is simply discarded.
#[macro_export]
macro_rules! assert_err {
    ($ex:expr) => {
        match $ex {
            Ok(v) => panic!("assertion failed, expected Err(..), got {:?}", v),
            Err(_) => {}
        }
    };
}
16
/// Build a [`std::io::Error`] of kind `ErrorKind::Other` whose payload is
/// `$msg.to_string()`.
///
/// The expansion uses fully qualified `::std::io` paths, so the macro works at call
/// sites that have not imported `std::io` themselves.
#[macro_export]
macro_rules! io_error_other {
    ($msg:expr) => {
        ::std::io::Error::other($msg.to_string())
    };
}
23
/// Wrap any thread-safe error value in an [`io::Error`] of kind [`io::ErrorKind::Other`].
///
/// The original error is kept as the inner source, so it remains reachable through
/// `get_ref` / `into_inner`. Being a plain function, it can be passed by name, e.g.
/// `result.map_err(to_io_error_other)`.
pub fn to_io_error_other<E>(x: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    io::Error::other(x)
}
27
/// Return early from the enclosing function with
/// `Err(std::io::Error::other($msg.to_string()))`.
///
/// The enclosing function must return `std::io::Result<_>`. Fully qualified paths
/// make the macro usable where `std::io` has not been imported.
#[macro_export]
macro_rules! bail_io_error_other {
    ($msg:expr) => {
        return ::std::io::Result::Err(::std::io::Error::other($msg.to_string()))
    };
}
34
35//////////////////////////////////////////////////////////////////////////////////////////////////////////////
36
37cfg_if! {
38    if #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] {
39        #[must_use]
40        pub fn get_concurrency() -> u32 {
41            std::thread::available_parallelism()
42                .map(|x| x.get())
43                .unwrap_or_else(|e| {
44                    warn!("unable to get concurrency defaulting to single core: {}", e);
45                    1
46                }) as u32
47        }
48    }
49}
50
51//////////////////////////////////////////////////////////////////////////////////////////////////////////////
52
/// Split `name` into a host part and an optional port at its **last** `:`.
///
/// * `"host:5150"` gives `("host", Some(5150))`.
/// * `"host"` (no colon) gives `("host", None)`.
///
/// NOTE(review): because the split happens at the final colon, an unbracketed IPv6
/// literal such as `"::1"` is split into `(":", Some(1))`. Bracket IPv6 hosts
/// (`"[::1]:80"`) to get the intended result.
///
/// # Errors
/// Returns `"invalid port: <reason>"` when the text after the last colon is not a
/// valid `u16`. This includes an empty port, as in `"host:"`.
pub fn split_port(name: &str) -> Result<(String, Option<u16>), String> {
    match name.rsplit_once(':') {
        None => Ok((name.to_string(), None)),
        Some((host, port_text)) => {
            let port = port_text
                .parse::<u16>()
                .map_err(|e| format!("invalid port: {}", e))?;
            Ok((host.to_string(), Some(port)))
        }
    }
}
66
/// Return `s` guaranteed to begin with `/`.
///
/// A string that already starts with `/` is returned unchanged (no reallocation).
/// Otherwise a single `/` is added in front, so `""` becomes `"/"`.
#[must_use]
pub fn prepend_slash(s: String) -> String {
    if s.starts_with('/') {
        s
    } else {
        format!("/{}", s)
    }
}
76
/// Convert a duration in microseconds to seconds as an `f64`.
///
/// The value is never cast from `u64` to `f64` directly. Instead it is narrowed to
/// at most 32 significant bits by shifting out its low-order bits, and that precision
/// is lost. It is then converted losslessly from `u32` and scaled by `1e-6` times the
/// matching power of two.
#[must_use]
pub fn timestamp_duration_to_secs(dur: u64) -> f64 {
    // Number of low bits that must be dropped so the value fits in a u32 (0 if it already fits).
    let shift = 32u32.saturating_sub(dur.leading_zeros());
    // After shifting, at most 32 significant bits remain, so this cast cannot truncate.
    let narrowed = (dur >> shift) as u32;
    // Multiplying by an exact power of two is exact in floating point, so this equals
    // doubling 1e-6 once per dropped bit.
    let scale = (1.0f64 / 1_000_000.0f64) * ((1u64 << shift) as f64);
    f64::from(narrowed) * scale
}
88
/// Convert seconds to whole microseconds.
///
/// Uses Rust's saturating float-to-integer `as` cast. Fractional microseconds are
/// truncated toward zero, negative inputs and NaN give `0`, and values beyond the
/// `u64` range give `u64::MAX`.
#[must_use]
pub fn secs_to_timestamp_duration(secs: f64) -> u64 {
    const MICROS_PER_SEC: f64 = 1_000_000.0;
    let micros = secs * MICROS_PER_SEC;
    micros as u64
}
93
/// Convert milliseconds to microseconds.
///
/// The widening to `u64` makes overflow impossible: `u32::MAX * 1000` fits easily.
#[must_use]
pub fn ms_to_us(ms: u32) -> u64 {
    u64::from(ms) * 1_000
}
98
/// Convert microseconds to whole milliseconds, truncating any remainder.
///
/// # Errors
/// Returns `"could not convert microseconds: <reason>"` when the millisecond count
/// does not fit in a `u32`.
pub fn us_to_ms(us: u64) -> Result<u32, String> {
    let ms = us / 1_000;
    u32::try_from(ms).map_err(|e| format!("could not convert microseconds: {}", e))
}
102
103// Calculate retry attempt with logarhythmic falloff
104#[must_use]
105pub fn retry_falloff_log(
106    last_us: u64,
107    cur_us: u64,
108    interval_start_us: u64,
109    interval_max_us: u64,
110    interval_multiplier_us: f64,
111) -> bool {
112    //
113    if cur_us < interval_start_us {
114        // Don't require a retry within the first 'interval_start_us' microseconds of the reliable time period
115        false
116    } else if cur_us >= last_us + interval_max_us {
117        // Retry at least every 'interval_max_us' microseconds
118        true
119    } else {
120        // Exponential falloff between 'interval_start_us' and 'interval_max_us' microseconds
121        last_us
122            <= secs_to_timestamp_duration(
123                timestamp_duration_to_secs(cur_us) / interval_multiplier_us,
124            )
125    }
126}
127
/// Call `closure` on successive items of `things` and return the first `Some` result.
///
/// At most `max` items are attempted. When `max` is `0`, `closure` is never called
/// and `None` is returned. Returns `None` if every attempted item produced `None` or
/// `things` ran out first.
pub fn try_at_most_n_things<T, I, C, R>(max: usize, things: I, closure: C) -> Option<R>
where
    I: IntoIterator<Item = T>,
    C: Fn(T) -> Option<R>,
{
    // The bound is applied before each attempt, so `max == 0` makes no attempts.
    things.into_iter().take(max).find_map(closure)
}
145
/// Async counterpart of [`try_at_most_n_things`]: await `closure` on successive items
/// of `things` and return the first `Some` result.
///
/// Items are tried one at a time, in order, and at most `max` of them are attempted.
/// When `max` is `0`, `closure` is never called. Returns `None` if every attempted
/// item produced `None` or `things` ran out first.
pub async fn async_try_at_most_n_things<T, I, C, R, F>(
    max: usize,
    things: I,
    closure: C,
) -> Option<R>
where
    I: IntoIterator<Item = T>,
    C: Fn(T) -> F,
    F: Future<Output = Option<R>>,
{
    // The bound is applied before each attempt, so `max == 0` makes no attempts.
    for thing in things.into_iter().take(max) {
        if let Some(result) = closure(thing).await {
            return Some(result);
        }
    }
    None
}
168
/// In-place `min` / `max`: overwrite `self` with `other` when `other` is strictly
/// smaller (respectively larger).
pub trait CmpAssign {
    /// Replace `self` with `other` if `other < self`; on a tie `self` is kept.
    fn min_assign(&mut self, other: Self);
    /// Replace `self` with `other` if `other > self`; on a tie `self` is kept.
    fn max_assign(&mut self, other: Self);
}

impl<T: Ord> CmpAssign for T {
    fn min_assign(&mut self, other: Self) {
        if other < *self {
            *self = other;
        }
    }

    fn max_assign(&mut self, other: Self) {
        if other > *self {
            *self = other;
        }
    }
}
189
/// The unspecified ("any") address with port `0`, in the same IP family as
/// `socket_addr`: `0.0.0.0:0` for IPv4 and `[::]:0` for IPv6.
#[must_use]
pub fn compatible_unspecified_socket_addr(socket_addr: &SocketAddr) -> SocketAddr {
    let unspecified = if socket_addr.is_ipv4() {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    } else {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    };
    SocketAddr::new(unspecified, 0)
}
197
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use std::net::UdpSocket;

/// Cached result of the IPv6 probe, filled in by the first call to `is_ipv6_supported`.
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
static IPV6_IS_SUPPORTED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Report whether this host can bind an IPv6 socket.
///
/// The first call probes by binding a UDP socket to `[::1]:0`, an ephemeral port on the
/// IPv6 loopback. The socket is dropped immediately and the outcome is cached for the
/// rest of the process, so later calls only read the cache and take no lock.
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub fn is_ipv6_supported() -> bool {
    *IPV6_IS_SUPPORTED.get_or_init(|| {
        // Not exhaustive but for our use case it should be sufficient. If no local ports
        // are available for binding, Veilid isn't going to work anyway.
        UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0)).is_ok()
    })
}
216
217#[must_use]
218pub fn available_unspecified_addresses() -> Vec<IpAddr> {
219    if is_ipv6_supported() {
220        vec![
221            IpAddr::V4(Ipv4Addr::UNSPECIFIED),
222            IpAddr::V6(Ipv6Addr::UNSPECIFIED),
223        ]
224    } else {
225        vec![IpAddr::V4(Ipv4Addr::UNSPECIFIED)]
226    }
227}
228
229pub fn listen_address_to_socket_addrs(listen_address: &str) -> Result<Vec<SocketAddr>, String> {
230    // If no address is specified, but the port is, use ipv4 and ipv6 unspecified
231    // If the address is specified, only use the specified port and fail otherwise
232
233    let ip_addrs = available_unspecified_addresses();
234
235    Ok(if let Some(portstr) = listen_address.strip_prefix(':') {
236        let port = portstr
237            .parse::<u16>()
238            .map_err(|e| format!("Invalid port format in udp listen address: {}", e))?;
239        ip_addrs.iter().map(|a| SocketAddr::new(*a, port)).collect()
240    } else if let Ok(port) = listen_address.parse::<u16>() {
241        ip_addrs.iter().map(|a| SocketAddr::new(*a, port)).collect()
242    } else {
243        let listen_address_with_port = if listen_address.contains(':') {
244            listen_address.to_string()
245        } else {
246            format!("{}:0", listen_address)
247        };
248        cfg_if! {
249            if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
250                use core::str::FromStr;
251                vec![SocketAddr::from_str(&listen_address_with_port).map_err(|e| format!("Unable to parse address: {}",e))?]
252            } else {
253                listen_address_with_port
254                    .to_socket_addrs()
255                    .map_err(|e| format!("Unable to resolve address: {}", e))?
256                    .collect()
257            }
258        }
259    })
260}
261
/// Order-preserving de-duplication that, unlike `Vec::dedup`, does not require the
/// elements to be sorted or adjacent.
pub trait RemoveDuplicates<T: PartialEq + Clone> {
    /// Keep the first occurrence of each value and drop every later value that
    /// compares equal to it.
    fn remove_duplicates(&mut self);
}

impl<T: PartialEq + Ord + Clone> RemoveDuplicates<T> for Vec<T> {
    fn remove_duplicates(&mut self) {
        // Equality here is the `Ord` ordering used by the set. Every kept element is
        // cloned into the set so later copies of it can be recognized.
        let mut seen = BTreeSet::<T>::new();
        let keep_first = |element: &T| seen.insert(element.clone());
        self.retain(keep_first);
    }
}
273
/// Duplicate detection that, unlike comparing neighbours, does not require the
/// elements to be sorted.
pub trait HasDuplicates<T: PartialEq + Clone> {
    /// `true` if any two elements compare equal.
    fn has_duplicates(&self) -> bool;
}

impl<T: PartialEq + Ord + Clone> HasDuplicates<T> for Vec<T> {
    fn has_duplicates(&self) -> bool {
        // Stores references only, so no element is cloned. The scan stops at the
        // first repeat.
        let mut seen = BTreeSet::new();
        self.iter().any(|element| !seen.insert(element))
    }
}
290
291cfg_if::cfg_if! {
292    if #[cfg(unix)] {
293        use std::os::unix::fs::MetadataExt;
294        use std::os::unix::prelude::PermissionsExt;
295
296        pub fn ensure_file_private_owner<P:AsRef<Path>>(path: P) -> Result<(), String>
297        {
298            let path = path.as_ref();
299            if !path.is_file() {
300                return Ok(());
301            }
302
303            let uid = unsafe { libc::geteuid() };
304            let gid = unsafe { libc::getegid() };
305            let meta = std::fs::metadata(path).map_err(|e| format!("unable to get metadata for path: {}", e))?;
306
307            if meta.mode() != 0o600 {
308                std::fs::set_permissions(path,std::fs::Permissions::from_mode(0o600)).map_err(|e| format!("unable to set correct permissions on path: {}", e))?;
309            }
310            if meta.uid() != uid || meta.gid() != gid {
311                return Err("path has incorrect owner/group".to_owned());
312            }
313            Ok(())
314        }
315
316        pub fn ensure_directory_private_owner<P:AsRef<Path>>(path: P, group_read: bool) -> Result<(), String>
317        {
318            let path = path.as_ref();
319            if !path.is_dir() {
320                return Ok(());
321            }
322
323            let uid = unsafe { libc::geteuid() };
324            let gid = unsafe { libc::getegid() };
325            let meta = std::fs::metadata(path).map_err(|e| format!("unable to get metadata for path: {}", e))?;
326
327            let perm = if group_read {
328                0o750
329            } else {
330                0o700
331            };
332
333            if meta.mode() != perm {
334                std::fs::set_permissions(path,std::fs::Permissions::from_mode(perm)).map_err(|e| format!("unable to set correct permissions on path: {}", e))?;
335            }
336            if meta.uid() != uid || meta.gid() != gid {
337                return Err("path has incorrect owner/group".to_owned());
338            }
339            Ok(())
340        }
341    } else if #[cfg(windows)] {
342        //use std::os::windows::fs::MetadataExt;
343        //use windows_permissions::*;
344
345        pub fn ensure_file_private_owner<P:AsRef<Path>>(path: P) -> Result<(), String>
346        {
347            let path = path.as_ref();
348            if !path.is_file() {
349                return Ok(());
350            }
351
352            Ok(())
353        }
354
355        pub fn ensure_directory_private_owner<P:AsRef<Path>>(path: P, _group_read: bool) -> Result<(), String>
356        {
357            let path = path.as_ref();
358            if !path.is_dir() {
359                return Ok(());
360            }
361
362            Ok(())
363        }
364
365    } else {
366        pub fn ensure_file_private_owner<P:AsRef<Path>>(path: P) -> Result<(), String>
367        {
368            let path = path.as_ref();
369            if !path.is_file() {
370                return Ok(());
371            }
372
373            Ok(())
374        }
375
376        pub fn ensure_directory_private_owner<P:AsRef<Path>>(path: P, _group_read: bool) -> Result<(), String>
377        {
378            let path = path.as_ref();
379            if !path.is_dir() {
380                return Ok(());
381            }
382
383            Ok(())
384        }
385    }
386}
387
/// Eight-byte unit whose only job is to give an allocation 8-byte alignment.
#[repr(C, align(8))]
struct AlignToEight([u8; 8]);

/// Allocate a `Vec<u8>` of length `n_bytes` whose buffer starts on an 8-byte boundary.
///
/// The buffer is allocated as `ceil(n_bytes / 8)` `AlignToEight` units and then
/// reinterpreted as bytes, so the capacity is a multiple of 8 and at least `n_bytes`.
///
/// NOTE(review): `Vec::from_raw_parts` requires `T` to have the alignment the buffer
/// was allocated with. Here it was allocated with alignment 8 but will be freed as a
/// `Vec<u8>` (alignment 1). This relies on the global allocator tolerating the
/// mismatch; confirm that is acceptable for the allocators this crate targets.
///
/// # Safety
/// The `n_bytes` elements of the returned vector are uninitialized and may hold stale,
/// possibly sensitive, memory contents. The caller must write every byte before any
/// byte is read.
#[must_use]
pub unsafe fn aligned_8_u8_vec_uninit(n_bytes: usize) -> Vec<u8> {
    let unit_size = mem::size_of::<AlignToEight>();
    // ManuallyDrop transfers ownership of the buffer to the returned Vec<u8> instead of
    // freeing it when `units` goes out of scope.
    let mut units =
        mem::ManuallyDrop::new(Vec::<AlignToEight>::with_capacity(n_bytes.div_ceil(unit_size)));
    let byte_capacity = units.capacity() * unit_size;
    let bytes = units.as_mut_ptr().cast::<u8>();
    // SAFETY: `bytes`/`byte_capacity` describe exactly the allocation owned by `units`,
    // which is never dropped; `n_bytes <= byte_capacity` because the unit count was
    // rounded up. The caller upholds the initialization contract above.
    unsafe { Vec::from_raw_parts(bytes, n_bytes, byte_capacity) }
}
407
/// Allocate a `Vec<u8>` of length `n_bytes` without initializing its contents.
///
/// # Safety
/// The `n_bytes` elements of the returned vector are uninitialized and may hold stale,
/// possibly sensitive, memory contents. The caller must write every byte before any
/// byte is read.
#[must_use]
pub unsafe fn unaligned_u8_vec_uninit(n_bytes: usize) -> Vec<u8> {
    // ManuallyDrop transfers ownership of the buffer to the returned Vec instead of
    // freeing it when `buffer` goes out of scope.
    let mut buffer = mem::ManuallyDrop::new(Vec::<u8>::with_capacity(n_bytes));
    // `with_capacity` only promises *at least* `n_bytes`; `from_raw_parts` must be given
    // the capacity actually allocated so the buffer is later freed with the right size.
    let capacity = buffer.capacity();
    let ptr = buffer.as_mut_ptr();
    // SAFETY: `ptr`/`capacity` describe exactly the allocation owned by `buffer`, which is
    // never dropped; `n_bytes <= capacity`. The caller upholds the initialization contract.
    unsafe { Vec::from_raw_parts(ptr, n_bytes, capacity) }
}
418
/// Name of the static type `T` of the referenced value, as given by
/// `core::any::type_name`.
///
/// This reports the compile-time type, not the dynamic type behind a trait object.
/// The exact text is a diagnostic aid and is not guaranteed to be stable.
pub fn type_name_of_val<T>(_value: &T) -> &'static str
where
    T: ?Sized,
{
    core::any::type_name::<T>()
}
422
/// Convert any `ToString` value into an owned `String`.
///
/// Being a plain function, it can be passed by name to combinators, e.g.
/// `.map_err(map_to_string)`.
pub fn map_to_string<X: ToString>(arg: X) -> String {
    ToString::to_string(&arg)
}
426
427//////////////////////////////////////////////////////////////////////////////////////////////////////////////
428
/// Debug helper that tracks how many guards sharing a counter are currently alive.
///
/// Creating a guard increments `counter` and prints `"<name> entered: <count>"` to
/// stderr. Dropping it decrements the counter and prints `"<name> exited: <count>"`.
#[must_use = "the guard decrements its counter when dropped; bind it to a named variable to keep it alive"]
pub struct DebugGuard {
    name: &'static str,
    counter: &'static AtomicUsize,
}

impl DebugGuard {
    /// Increment `counter` and report the new count under `name`.
    pub fn new(name: &'static str, counter: &'static AtomicUsize) -> Self {
        let previous = counter.fetch_add(1, Ordering::SeqCst);
        eprintln!("{} entered: {}", name, previous.wrapping_add(1));
        Self { name, counter }
    }
}

impl Drop for DebugGuard {
    fn drop(&mut self) {
        let previous = self.counter.fetch_sub(1, Ordering::SeqCst);
        // `fetch_sub` wraps, so mirror that here. A counter reset elsewhere must not
        // cause an overflow panic inside `drop`.
        eprintln!("{} exited: {}", self.name, previous.wrapping_sub(1));
    }
}