trailbase_extension/
geoip.rs1use arc_swap::ArcSwap;
2use maxminddb::{MaxMindDbError, Reader, geoip2};
3use rusqlite::Error;
4use rusqlite::functions::Context;
5use serde::{Deserialize, Serialize};
6use std::net::IpAddr;
7use std::path::Path;
8use std::sync::LazyLock;
9
10type MaxMindReader = Reader<Vec<u8>>;
11
12static READER: LazyLock<ArcSwap<Option<MaxMindReader>>> =
13 LazyLock::new(|| ArcSwap::from_pointee(None));
14
15#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
16pub struct City {
17 pub country_code: Option<String>,
18 pub name: Option<String>,
19 pub subdivisions: Option<Vec<String>>,
20}
21
22impl City {
23 fn from(city: &geoip2::City) -> Self {
24 return Self {
25 name: extract_city_name(city),
26 country_code: city
27 .country
28 .as_ref()
29 .and_then(|c| Some(c.iso_code?.to_string())),
30 subdivisions: extract_subdivision_names(city),
31 };
32 }
33}
34
35pub fn load_geoip_db(path: impl AsRef<Path>) -> Result<(), MaxMindDbError> {
36 let reader = Reader::open_readfile(path)?;
37 log::debug!("Loaded geoip DB: {:?}", reader.metadata);
38 READER.swap(Some(reader).into());
39 return Ok(());
40}
41
42pub fn has_geoip_db() -> bool {
43 return READER.load().is_some();
44}
45
/// Edition of the loaded MaxMind database, derived from its metadata `database_type` string.
#[derive(Debug, PartialEq, Clone)]
pub enum DatabaseType {
  /// Any `database_type` string not listed below.
  Unknown,
  /// `"GeoLite2-Country"`.
  GeoLite2Country,
  /// `"GeoLite2-City"`.
  GeoLite2City,
  /// `"GeoLite2-ASN"`.
  GeoLite2ASN,
}
53
54pub fn database_type() -> Option<DatabaseType> {
55 if let Some(ref reader) = **READER.load() {
56 return Some(match reader.metadata.database_type.as_str() {
57 "GeoLite2-Country" => DatabaseType::GeoLite2Country,
58 "GeoLite2-City" => DatabaseType::GeoLite2City,
59 "GeoLite2-ASN" => DatabaseType::GeoLite2ASN,
61 _ => DatabaseType::Unknown,
62 });
63 }
64 return None;
65}
66
67pub(crate) fn geoip_country(context: &Context) -> Result<Option<String>, Error> {
68 return geoip_extract(context, |reader, client_ip| {
69 if let Ok(Some(country)) = reader.lookup::<geoip2::Country>(client_ip) {
70 return Some(country.country?.iso_code?.to_string());
71 }
72
73 return None;
74 });
75}
76
77pub(crate) fn geoip_city_json(context: &Context) -> Result<Option<String>, Error> {
78 return geoip_extract(context, |reader, client_ip| {
79 if let Ok(Some(ref city)) = reader.lookup::<geoip2::City>(client_ip) {
80 return serde_json::to_string(&City::from(city)).ok();
81 }
82
83 return None;
84 });
85}
86
87pub(crate) fn geoip_city_name(context: &Context) -> Result<Option<String>, Error> {
88 return geoip_extract(context, |reader, client_ip| {
89 if let Ok(Some(ref city)) = reader.lookup::<geoip2::City>(client_ip) {
90 return extract_city_name(city);
91 }
92
93 return None;
94 });
95}
96
97#[inline]
98fn geoip_extract(
99 context: &Context,
100 f: impl Fn(&MaxMindReader, IpAddr) -> Option<String>,
101) -> Result<Option<String>, Error> {
102 #[cfg(debug_assertions)]
103 if context.len() != 1 {
104 return Err(Error::InvalidParameterCount(context.len(), 1));
105 }
106
107 let Some(text) = context.get_raw(0).as_str_or_null()? else {
108 return Ok(None);
109 };
110
111 if !text.is_empty() {
112 let client_ip: IpAddr = text.parse().map_err(|err| {
113 Error::UserFunctionError(format!("Parsing ip '{text:?}' failed: {err}").into())
114 })?;
115
116 if let Some(ref reader) = **READER.load() {
117 return Ok(f(reader, client_ip));
118 }
119 }
120
121 Ok(None)
122}
123
124fn extract_city_name(city: &geoip2::City) -> Option<String> {
125 return city.city.as_ref().and_then(|c| {
126 if let Some(ref names) = c.names {
127 if let Some(city_name) = names.get("en") {
128 return Some(city_name.to_string());
129 }
130
131 if let Some((_locale, city_name)) = names.first_key_value() {
132 return Some(city_name.to_string());
133 }
134 }
135 return None;
136 });
137}
138
139fn extract_subdivision_names(city: &geoip2::City) -> Option<Vec<String>> {
140 return city.subdivisions.as_ref().map(|divisions| {
141 return divisions
142 .iter()
143 .filter_map(|s| {
144 if let Some(ref names) = s.names {
145 if let Some(city_name) = names.get("en") {
146 return Some(city_name.to_string());
147 }
148
149 if let Some((_locale, city_name)) = names.first_key_value() {
150 return Some(city_name.to_string());
151 }
152 }
153 return None;
154 })
155 .collect();
156 });
157}
158
#[cfg(test)]
mod tests {
  use super::*;

  /// End-to-end check of the `geoip_*` SQL functions before and after loading a test DB.
  #[test]
  fn test_geoip_functions() {
    let ip = "89.160.20.112";
    let conn = crate::connect_sqlite(None, None).unwrap();

    // Bind the IP as a query parameter instead of splicing it into the SQL text.
    let query = |sql: &str, ip: &str| -> Option<String> {
      conn.query_row(sql, [ip], |row| row.get(0)).unwrap()
    };

    // No database loaded yet: lookups yield NULL.
    assert_eq!(query("SELECT geoip_country(?1)", ip), None);

    load_geoip_db("testdata/GeoIP2-City-Test.mmdb").unwrap();

    assert_eq!(query("SELECT geoip_country(?1)", ip).as_deref(), Some("SE"));
    assert_eq!(
      query("SELECT geoip_city_name(?1)", ip).as_deref(),
      Some("Linköping")
    );

    // Decode outside of the row closure so a JSON error is not hidden inside `query_row`.
    let city_json = query("SELECT geoip_city_json(?1)", ip).expect("city JSON for known IP");
    let city: City = serde_json::from_str(&city_json).unwrap();

    assert_eq!(
      city,
      City {
        country_code: Some("SE".to_string()),
        name: Some("Linköping".to_string()),
        subdivisions: Some(vec!["Östergötland County".to_string()]),
      }
    );

    // An address without an entry in the test database yields NULL.
    assert_eq!(query("SELECT geoip_country(?1)", "127.0.0.1"), None);
  }
}