trailbase_extension/
geoip.rs

1use arc_swap::ArcSwap;
2use maxminddb::{MaxMindDbError, Reader, geoip2};
3use rusqlite::Error;
4use rusqlite::functions::Context;
5use serde::{Deserialize, Serialize};
6use std::net::IpAddr;
7use std::path::Path;
8use std::sync::LazyLock;
9
10type MaxMindReader = Reader<Vec<u8>>;
11
12static READER: LazyLock<ArcSwap<Option<MaxMindReader>>> =
13  LazyLock::new(|| ArcSwap::from_pointee(None));
14
15#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
16pub struct City {
17  pub country_code: Option<String>,
18  pub name: Option<String>,
19  pub subdivisions: Option<Vec<String>>,
20}
21
22impl City {
23  fn from(city: &geoip2::City) -> Self {
24    return Self {
25      name: extract_city_name(city),
26      country_code: city
27        .country
28        .as_ref()
29        .and_then(|c| Some(c.iso_code?.to_string())),
30      subdivisions: extract_subdivision_names(city),
31    };
32  }
33}
34
35pub fn load_geoip_db(path: impl AsRef<Path>) -> Result<(), MaxMindDbError> {
36  let reader = Reader::open_readfile(path)?;
37  log::debug!("Loaded geoip DB: {:?}", reader.metadata);
38  READER.swap(Some(reader).into());
39  return Ok(());
40}
41
42pub fn has_geoip_db() -> bool {
43  return READER.load().is_some();
44}
45
/// Kind of MaxMind database currently loaded, as reported by [`database_type`].
#[derive(Debug, Clone, PartialEq)]
pub enum DatabaseType {
  /// The metadata names a database type that is not recognized.
  Unknown,
  /// Country-level database, such as `GeoLite2-Country`.
  GeoLite2Country,
  /// City-level database, such as `GeoLite2-City`.
  GeoLite2City,
  /// Autonomous system number database (`GeoLite2-ASN`).
  GeoLite2ASN,
}
53
54pub fn database_type() -> Option<DatabaseType> {
55  if let Some(ref reader) = **READER.load() {
56    return Some(match reader.metadata.database_type.as_str() {
57      "GeoLite2-Country" => DatabaseType::GeoLite2Country,
58      "GeoLite2-City" => DatabaseType::GeoLite2City,
59      // Autonomous system number.
60      "GeoLite2-ASN" => DatabaseType::GeoLite2ASN,
61      _ => DatabaseType::Unknown,
62    });
63  }
64  return None;
65}
66
67pub(crate) fn geoip_country(context: &Context) -> Result<Option<String>, Error> {
68  return geoip_extract(context, |reader, client_ip| {
69    if let Ok(Some(country)) = reader.lookup::<geoip2::Country>(client_ip) {
70      return Some(country.country?.iso_code?.to_string());
71    }
72
73    return None;
74  });
75}
76
77pub(crate) fn geoip_city_json(context: &Context) -> Result<Option<String>, Error> {
78  return geoip_extract(context, |reader, client_ip| {
79    if let Ok(Some(ref city)) = reader.lookup::<geoip2::City>(client_ip) {
80      return serde_json::to_string(&City::from(city)).ok();
81    }
82
83    return None;
84  });
85}
86
87pub(crate) fn geoip_city_name(context: &Context) -> Result<Option<String>, Error> {
88  return geoip_extract(context, |reader, client_ip| {
89    if let Ok(Some(ref city)) = reader.lookup::<geoip2::City>(client_ip) {
90      return extract_city_name(city);
91    }
92
93    return None;
94  });
95}
96
97#[inline]
98fn geoip_extract(
99  context: &Context,
100  f: impl Fn(&MaxMindReader, IpAddr) -> Option<String>,
101) -> Result<Option<String>, Error> {
102  #[cfg(debug_assertions)]
103  if context.len() != 1 {
104    return Err(Error::InvalidParameterCount(context.len(), 1));
105  }
106
107  let Some(text) = context.get_raw(0).as_str_or_null()? else {
108    return Ok(None);
109  };
110
111  if !text.is_empty() {
112    let client_ip: IpAddr = text.parse().map_err(|err| {
113      Error::UserFunctionError(format!("Parsing ip '{text:?}' failed: {err}").into())
114    })?;
115
116    if let Some(ref reader) = **READER.load() {
117      return Ok(f(reader, client_ip));
118    }
119  }
120
121  Ok(None)
122}
123
124fn extract_city_name(city: &geoip2::City) -> Option<String> {
125  return city.city.as_ref().and_then(|c| {
126    if let Some(ref names) = c.names {
127      if let Some(city_name) = names.get("en") {
128        return Some(city_name.to_string());
129      }
130
131      if let Some((_locale, city_name)) = names.first_key_value() {
132        return Some(city_name.to_string());
133      }
134    }
135    return None;
136  });
137}
138
139fn extract_subdivision_names(city: &geoip2::City) -> Option<Vec<String>> {
140  return city.subdivisions.as_ref().map(|divisions| {
141    return divisions
142      .iter()
143      .filter_map(|s| {
144        if let Some(ref names) = s.names {
145          if let Some(city_name) = names.get("en") {
146            return Some(city_name.to_string());
147          }
148
149          if let Some((_locale, city_name)) = names.first_key_value() {
150            return Some(city_name.to_string());
151          }
152        }
153        return None;
154      })
155      .collect();
156  });
157}
158
#[cfg(test)]
mod tests {
  use super::*;

  /// End-to-end check of the `geoip_*` SQL functions before and after loading the MaxMind
  /// test database.
  ///
  /// NOTE(review): this mutates the process-wide `READER`. The initial "no database"
  /// assertion assumes no other test has loaded a database first.
  #[test]
  fn test_geoip_sql_functions() {
    let ip = "89.160.20.112";
    let conn = crate::connect_sqlite(None, None).unwrap();

    // Without a loaded database every lookup yields NULL.
    let cc: Option<String> = conn
      .query_row("SELECT geoip_country(?1)", [ip], |row| row.get(0))
      .unwrap();
    assert_eq!(cc, None);

    load_geoip_db("testdata/GeoIP2-City-Test.mmdb").unwrap();

    let cc: String = conn
      .query_row("SELECT geoip_country(?1)", [ip], |row| row.get(0))
      .unwrap();
    assert_eq!(cc, "SE");

    let city_name: String = conn
      .query_row("SELECT geoip_city_name(?1)", [ip], |row| row.get(0))
      .unwrap();
    assert_eq!(city_name, "Linköping");

    let json: String = conn
      .query_row("SELECT geoip_city_json(?1)", [ip], |row| row.get(0))
      .unwrap();
    let city: City = serde_json::from_str(&json).unwrap();
    assert_eq!(
      city,
      City {
        country_code: Some("SE".to_string()),
        name: Some("Linköping".to_string()),
        subdivisions: Some(vec!["Östergötland County".to_string()]),
      }
    );

    // Addresses not covered by the database yield NULL.
    let cc: Option<String> = conn
      .query_row("SELECT geoip_country(?1)", ["127.0.0.1"], |row| row.get(0))
      .unwrap();
    assert_eq!(cc, None);

    // NULL and empty-string arguments yield NULL rather than an error.
    let cc: Option<String> = conn
      .query_row("SELECT geoip_country(NULL)", (), |row| row.get(0))
      .unwrap();
    assert_eq!(cc, None);

    let cc: Option<String> = conn
      .query_row("SELECT geoip_country(?1)", [""], |row| row.get(0))
      .unwrap();
    assert_eq!(cc, None);

    // Text that is not an IP address is reported as an error.
    let invalid = conn.query_row("SELECT geoip_country(?1)", ["not-an-ip"], |row| {
      row.get::<_, Option<String>>(0)
    });
    assert!(invalid.is_err());
  }
}