1#![allow(renamed_and_removed_lints)] #![allow(unknown_lints)] #![warn(missing_docs)]
7#![warn(noop_method_call)]
8#![warn(unreachable_pub)]
9#![warn(clippy::all)]
10#![deny(clippy::await_holding_lock)]
11#![deny(clippy::cargo_common_metadata)]
12#![deny(clippy::cast_lossless)]
13#![deny(clippy::checked_conversions)]
14#![warn(clippy::cognitive_complexity)]
15#![deny(clippy::debug_assert_with_mut_call)]
16#![deny(clippy::exhaustive_enums)]
17#![deny(clippy::exhaustive_structs)]
18#![deny(clippy::expl_impl_clone_on_copy)]
19#![deny(clippy::fallible_impl_from)]
20#![deny(clippy::implicit_clone)]
21#![deny(clippy::large_stack_arrays)]
22#![warn(clippy::manual_ok_or)]
23#![deny(clippy::missing_docs_in_private_items)]
24#![warn(clippy::needless_borrow)]
25#![warn(clippy::needless_pass_by_value)]
26#![warn(clippy::option_option)]
27#![deny(clippy::print_stderr)]
28#![deny(clippy::print_stdout)]
29#![warn(clippy::rc_buffer)]
30#![deny(clippy::ref_option_ref)]
31#![warn(clippy::semicolon_if_nothing_returned)]
32#![warn(clippy::trait_duplication_in_bounds)]
33#![deny(clippy::unchecked_duration_subtraction)]
34#![deny(clippy::unnecessary_wraps)]
35#![warn(clippy::unseparated_literal_suffix)]
36#![deny(clippy::unwrap_used)]
37#![deny(clippy::mod_module_files)]
38#![allow(clippy::let_unit_value)] #![allow(clippy::uninlined_format_args)]
40#![allow(clippy::significant_drop_in_scrutinee)] #![allow(clippy::result_large_err)] #![allow(clippy::needless_raw_string_hashes)] #![allow(clippy::needless_lifetimes)] #![allow(mismatched_lifetime_syntaxes)] #![cfg_attr(not(all(feature = "full")), allow(unused))]
49
50pub use crate::err::Error;
51use rangemap::RangeInclusiveMap;
52use std::fmt::{Debug, Display, Formatter};
53use std::net::{IpAddr, Ipv6Addr};
54use std::num::{NonZeroU8, NonZeroU32, TryFromIntError};
55use std::str::FromStr;
56use std::sync::{Arc, OnceLock};
57
58mod err;
59
/// Text of the IPv4 half of the embedded GeoIP database, in the legacy
/// comma-separated format read by `GeoipDb::new_from_legacy_format`.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V4: &str = include_str!("../data/geoip");

/// Text of the IPv6 half of the embedded GeoIP database, in the legacy
/// comma-separated format read by `GeoipDb::new_from_legacy_format`.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V6: &str = include_str!("../data/geoip6");

/// Cache of the parsed embedded database.
///
/// Filled on the first call to `GeoipDb::new_embedded`; later calls hand out
/// clones of the same `Arc` instead of re-parsing the embedded text.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_PARSED: OnceLock<Arc<GeoipDb>> = OnceLock::new();
74
/// A two-character country code such as `US` or `GB`.
///
/// Values produced by `CountryCode::new` hold exactly two ASCII, non-control
/// bytes, with any ASCII letters upper-cased. The special "nowhere" code `??`
/// cannot be stored in this type; [`OptionCc`] represents it as `None`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CountryCode {
    /// The two bytes of the code, in order.
    ///
    /// Kept as `NonZeroU8` rather than `u8`; presumably so that
    /// `Option<CountryCode>` can use the zero byte as a niche and stay two
    /// bytes wide (TODO confirm this is the intent).
    inner: [NonZeroU8; 2],
}
100
101impl CountryCode {
102    fn new(cc_orig: &str) -> Result<Self, Error> {
104        #[inline]
106        fn try_cvt_to_nz(inp: [u8; 2]) -> Result<[NonZeroU8; 2], TryFromIntError> {
107            Ok([inp[0].try_into()?, inp[1].try_into()?])
109        }
110
111        let cc = cc_orig.to_ascii_uppercase();
112
113        let cc: [u8; 2] = cc
114            .as_bytes()
115            .try_into()
116            .map_err(|_| Error::BadCountryCode(cc))?;
117
118        if !cc.iter().all(|b| b.is_ascii() && !b.is_ascii_control()) {
119            return Err(Error::BadCountryCode(cc_orig.to_owned()));
120        }
121
122        if &cc == b"??" {
123            return Err(Error::NowhereNotSupported);
124        }
125
126        Ok(Self {
127            inner: try_cvt_to_nz(cc).map_err(|_| Error::BadCountryCode(cc_orig.to_owned()))?,
128        })
129    }
130
131    pub fn get(&self) -> &str {
135        self.as_ref()
136    }
137}
138
139impl Display for CountryCode {
140    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
141        write!(f, "{}", self.as_ref())
142    }
143}
144
145impl Debug for CountryCode {
146    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147        write!(f, "CountryCode(\"{}\")", self.as_ref())
148    }
149}
150
impl AsRef<str> for CountryCode {
    /// Return the code as a string slice borrowed from `self`.
    ///
    /// # Panics
    ///
    /// Panics if the stored bytes are not valid UTF-8. `CountryCode::new` (the
    /// constructor visible in this file) only accepts ASCII bytes, so this is
    /// not expected to happen for values built there.
    fn as_ref(&self) -> &str {
        /// Reinterpret a `&[NonZeroU8; 2]` as a `&[u8; 2]` without copying.
        #[inline]
        fn cvt_ref(inp: &[NonZeroU8; 2]) -> &[u8; 2] {
            let ptr = inp.as_ptr() as *const u8;
            // SAFETY: the standard library guarantees `NonZeroU8` has the same
            // layout as `u8`, and every valid `NonZeroU8` is a valid `u8`. `ptr`
            // points at `inp.len()` initialized elements that remain borrowed
            // for as long as `inp` is, and the resulting slice is read-only, so
            // no zero byte can ever be written back through it.
            let slice = unsafe { std::slice::from_raw_parts(ptr, inp.len()) };
            slice
                .try_into()
                .expect("the resulting slice should have the correct length!")
        }

        // Bytes accepted by `CountryCode::new` are ASCII, hence valid UTF-8.
        std::str::from_utf8(cvt_ref(&self.inner)).expect("invalid country code in CountryCode")
    }
}
177
178impl FromStr for CountryCode {
179    type Err = Error;
180
181    fn from_str(s: &str) -> Result<Self, Self::Err> {
182        CountryCode::new(s)
183    }
184}
185
186#[derive(
190    Copy, Clone, Debug, Eq, PartialEq, derive_more::Into, derive_more::From, derive_more::AsRef,
191)]
192#[allow(clippy::exhaustive_structs)]
193pub struct OptionCc(pub Option<CountryCode>);
194
195impl FromStr for OptionCc {
196    type Err = Error;
197
198    fn from_str(s: &str) -> Result<Self, Self::Err> {
199        match CountryCode::new(s) {
200            Err(Error::NowhereNotSupported) => Ok(None.into()),
201            Err(e) => Err(e),
202            Ok(cc) => Ok(Some(cc).into()),
203        }
204    }
205}
206
207impl Display for OptionCc {
208    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
209        match self.0 {
210            Some(cc) => write!(f, "{}", cc),
211            None => write!(f, "??"),
212        }
213    }
214}
215
216#[derive(Copy, Clone, Eq, PartialEq, Debug)]
220struct NetDefn {
221    cc: Option<CountryCode>,
225    asn: Option<NonZeroU32>,
227}
228
229impl NetDefn {
230    fn new(cc: &str, asn: Option<u32>) -> Result<Self, Error> {
232        let asn = NonZeroU32::new(asn.unwrap_or(0));
233        let cc = cc.parse::<OptionCc>()?.into();
234
235        Ok(Self { cc, asn })
236    }
237
238    fn country_code(&self) -> Option<&CountryCode> {
240        self.cc.as_ref()
241    }
242
243    fn asn(&self) -> Option<u32> {
245        self.asn.as_ref().map(|x| x.get())
246    }
247}
248
249#[derive(Clone, Eq, PartialEq, Debug)]
251pub struct GeoipDb {
252    map_v4: RangeInclusiveMap<u32, NetDefn>,
254    map_v6: RangeInclusiveMap<u128, NetDefn>,
256}
257
258impl GeoipDb {
259    #[cfg(feature = "embedded-db")]
264    pub fn new_embedded() -> Arc<Self> {
265        Arc::clone(EMBEDDED_DB_PARSED.get_or_init(|| {
266            Arc::new(
267                Self::new_from_legacy_format(EMBEDDED_DB_V4, EMBEDDED_DB_V6)
269                    .expect("failed to parse embedded geoip database"),
270            )
271        }))
272    }
273
274    pub fn new_from_legacy_format(db_v4: &str, db_v6: &str) -> Result<Self, Error> {
276        let mut ret = GeoipDb {
277            map_v4: Default::default(),
278            map_v6: Default::default(),
279        };
280
281        for line in db_v4.lines() {
282            if line.starts_with('#') {
283                continue;
284            }
285            let line = line.trim();
286            if line.is_empty() {
287                continue;
288            }
289            let mut split = line.split(',');
290            let from = split
291                .next()
292                .ok_or(Error::BadFormat("empty line somehow?"))?
293                .parse::<u32>()?;
294            let to = split
295                .next()
296                .ok_or(Error::BadFormat("line with insufficient commas"))?
297                .parse::<u32>()?;
298            let cc = split
299                .next()
300                .ok_or(Error::BadFormat("line with insufficient commas"))?;
301            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
302
303            let defn = NetDefn::new(cc, asn)?;
304
305            ret.map_v4.insert(from..=to, defn);
306        }
307
308        for line in db_v6.lines() {
310            if line.starts_with('#') {
311                continue;
312            }
313            let line = line.trim();
314            if line.is_empty() {
315                continue;
316            }
317            let mut split = line.split(',');
318            let from = split
319                .next()
320                .ok_or(Error::BadFormat("empty line somehow?"))?
321                .parse::<Ipv6Addr>()?;
322            let to = split
323                .next()
324                .ok_or(Error::BadFormat("line with insufficient commas"))?
325                .parse::<Ipv6Addr>()?;
326            let cc = split
327                .next()
328                .ok_or(Error::BadFormat("line with insufficient commas"))?;
329            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
330
331            let defn = NetDefn::new(cc, asn)?;
332
333            ret.map_v6.insert(from.into()..=to.into(), defn);
334        }
335
336        Ok(ret)
337    }
338
339    fn lookup_defn(&self, ip: IpAddr) -> Option<&NetDefn> {
341        match ip {
342            IpAddr::V4(v4) => self.map_v4.get(&v4.into()),
343            IpAddr::V6(v6) => self.map_v6.get(&v6.into()),
344        }
345    }
346
347    pub fn lookup_country_code(&self, ip: IpAddr) -> Option<&CountryCode> {
349        self.lookup_defn(ip).and_then(|x| x.country_code())
350    }
351
352    pub fn lookup_country_code_multi<I>(&self, ips: I) -> Option<&CountryCode>
358    where
359        I: IntoIterator<Item = IpAddr>,
360    {
361        let mut ret = None;
362
363        for ip in ips {
364            if let Some(cc) = self.lookup_country_code(ip) {
365                if ret.is_some() && ret != Some(cc) {
368                    return None;
369                }
370
371                ret = Some(cc);
372            }
373        }
374
375        ret
376    }
377
378    pub fn lookup_asn(&self, ip: IpAddr) -> Option<u32> {
380        self.lookup_defn(ip)?.asn()
381    }
382}
383
/// Something that may have a [`CountryCode`] associated with it.
pub trait HasCountryCode {
    /// Return the country code associated with `self`, or `None` if none is known.
    ///
    /// NOTE(review): how implementors obtain the code (e.g. a GeoIP lookup of
    /// their addresses) is not shown in this file — confirm against the impls.
    fn country_code(&self) -> Option<CountryCode>;
}
398
#[cfg(test)]
mod test {
    // Test code may use constructs (unwrap, printing, clones of Copy values,
    // ...) that the crate-level lint settings reject in library code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use std::net::Ipv4Addr;

    // Spot-check one IPv4 and one IPv6 lookup against the embedded database
    // (only compiled with the `embedded-db` feature).
    #[test]
    #[cfg(feature = "embedded-db")]
    fn embedded_db() {
        let db = GeoipDb::new_embedded();

        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(8, 8, 8, 8).into())
                .map(|x| x.as_ref()),
            Some("US")
        );

        assert_eq!(
            db.lookup_country_code("2001:4860:4860::8888".parse().unwrap())
                .map(|x| x.as_ref()),
            Some("US")
        );
    }

    #[test]
    fn basic_lookups() {
        // 16909056..=16909311 is 1.2.3.0..=1.2.3.255 written as u32s.
        let src_v4 = r#"
        16909056,16909311,GB
        "#;
        // The second IPv6 range is marked `??`, so it yields no country code.
        let src_v6 = r#"
        fe80::,fe81::,US
        dead:beef::,dead:ffff::,??
        "#;
        let db = GeoipDb::new_from_legacy_format(src_v4, src_v6).unwrap();

        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 2, 3, 4).into())
                .map(|x| x.as_ref()),
            Some("GB")
        );

        // Outside every IPv4 range.
        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 1, 1, 1).into()),
            None
        );

        assert_eq!(
            db.lookup_country_code("fe80::dead:beef".parse().unwrap())
                .map(|x| x.as_ref()),
            Some("US")
        );

        // Just past the end of the inclusive `fe80::..=fe81::` range.
        assert_eq!(
            db.lookup_country_code("fe81::dead:beef".parse().unwrap()),
            None
        );
        // Inside the `??` range.
        assert_eq!(
            db.lookup_country_code("dead:beef::1".parse().unwrap()),
            None
        );
    }

    #[test]
    fn cc_parse() -> Result<(), Error> {
        // Parsing ignores ASCII case.
        assert_eq!(CountryCode::from_str("us")?, CountryCode::from_str("US")?);
        assert_eq!(CountryCode::from_str("UY")?, CountryCode::from_str("UY")?);

        // Codes need not be assigned countries; any two acceptable bytes parse.
        assert_eq!(CountryCode::from_str("A7")?, CountryCode::from_str("a7")?);
        assert_eq!(CountryCode::from_str("xz")?, CountryCode::from_str("xz")?);

        // Inputs that are not exactly two bytes long.
        assert!(matches!(
            CountryCode::from_str("z"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("🐻❄️"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("Sheboygan"),
            Err(Error::BadCountryCode(_))
        ));

        // Two bytes, but control characters or non-ASCII.
        assert!(matches!(
            CountryCode::from_str("\r\n"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("\0\0"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("¡"),
            Err(Error::BadCountryCode(_))
        ));

        // `??` gets its own error; `OptionCc` is the type that accepts it.
        assert!(matches!(
            CountryCode::from_str("??"),
            Err(Error::NowhereNotSupported)
        ));

        Ok(())
    }

    // `OptionCc` parses ordinary codes like `CountryCode` and maps `??` to `None`.
    #[test]
    fn opt_cc_parse() -> Result<(), Error> {
        assert_eq!(
            CountryCode::from_str("br")?,
            OptionCc::from_str("BR")?.0.unwrap()
        );
        assert!(OptionCc::from_str("??")?.0.is_none());

        Ok(())
    }
}