statsig_rust/evaluation/
country_lookup.rs

1use super::{dynamic_string::DynamicString, evaluator_context::EvaluatorContext};
2use crate::{
3    dyn_value, log_d, log_e, unwrap_or_return_with, user::StatsigUserInternal, DynamicValue,
4};
5use parking_lot::RwLock;
6use std::sync::Arc;
7
/// Namespace type for the IP-to-country lookup.
///
/// Carries no state of its own; all functionality is exposed through associated
/// functions that operate on the module-level lookup table.
pub struct CountryLookup;
9
/// Parsed form of the embedded IP-to-country table.
///
/// The two vectors are parallel: `country_codes[i]` is the code for the range whose
/// end boundary is `ip_ranges[i]`.
pub struct CountryLookupData {
    // Two-character country code for each range ("--" is treated as "no country").
    country_codes: Vec<String>,
    // Cumulative end boundary of each range, as a numeric IPv4 address.
    ip_ranges: Vec<i64>,
}
14
15lazy_static::lazy_static! {
16    static ref COUNTRY_LOOKUP_DATA: Arc<RwLock<Option<CountryLookupData>>> = Arc::from(RwLock::from(None));
17    static ref IP: String = "ip".to_string();
18}
19
/// Tag attached to every log line emitted by this module.
const TAG: &str = "CountryLookup";
/// Override reason recorded on the evaluation result when the table is unavailable.
const UNINITIALIZED_REASON: &str = "CountryLookupNotLoaded";
22
/// Adds a C-style post-increment (`i++`) to a cursor value.
pub trait UsizeExt {
    /// Increments `self` by one and returns the value it held *before* the increment.
    fn post_inc(&mut self) -> Self;
}

impl UsizeExt for usize {
    fn post_inc(&mut self) -> Self {
        // Capture the current value first so the caller can use it as an index
        // while the cursor itself already points at the next element.
        let was = *self;
        *self += 1;
        was
    }
}
34
35impl CountryLookup {
36    pub fn load_country_lookup() {
37        match COUNTRY_LOOKUP_DATA.try_read_for(std::time::Duration::from_secs(5)) {
38            Some(lock) => {
39                if lock.is_some() {
40                    log_d!(TAG, "Country Lookup already loaded");
41                    return;
42                }
43            }
44            None => {
45                log_e!(
46                    TAG,
47                    "Failed to acquire read lock on country lookup: Failed to lock COUNTRY_LOOKUP_DATA"
48                );
49                return;
50            }
51        }
52
53        let bytes = include_bytes!("../../resources/ip_supalite.table");
54
55        let mut raw_code_lookup: Vec<String> = vec![];
56        let mut country_codes: Vec<String> = vec![];
57        let mut ip_ranges: Vec<i64> = vec![];
58
59        let mut i = 0;
60
61        while i < bytes.len() {
62            let c1 = bytes[i.post_inc()] as char;
63            let c2 = bytes[i.post_inc()] as char;
64
65            raw_code_lookup.push(format!("{c1}{c2}"));
66
67            if c1 == '*' {
68                break;
69            }
70        }
71
72        let longs = |index: usize| bytes[index] as i64;
73
74        let mut last_end_range = 0_i64;
75        while (i + 1) < bytes.len() {
76            let mut count: i64 = 0;
77            let n1 = longs(i.post_inc());
78            if n1 < 240 {
79                count = n1;
80            } else if n1 == 242 {
81                let n2 = longs(i.post_inc());
82                let n3 = longs(i.post_inc());
83                count = n2 | (n3 << 8);
84            } else if n1 == 243 {
85                let n2 = longs(i.post_inc());
86                let n3 = longs(i.post_inc());
87                let n4 = longs(i.post_inc());
88                count = n2 | (n3 << 8) | (n4 << 16);
89            }
90
91            last_end_range += count * 256;
92
93            let cc = bytes[i.post_inc()] as usize;
94            ip_ranges.push(last_end_range);
95            country_codes.push(raw_code_lookup[cc].clone())
96        }
97
98        let country_lookup = CountryLookupData {
99            country_codes,
100            ip_ranges,
101        };
102
103        match COUNTRY_LOOKUP_DATA.try_write_for(std::time::Duration::from_secs(5)) {
104            Some(mut lock) => {
105                *lock = Some(country_lookup);
106                log_d!(TAG, " Successfully Loaded");
107            }
108            None => {
109                log_e!(
110                    TAG,
111                    "Failed to acquire write lock on country_lookup: Failed to lock COUNTRY_LOOKUP_DATA"
112                );
113            }
114        }
115    }
116
117    pub fn get_value_from_ip(
118        user: &StatsigUserInternal,
119        field: &Option<DynamicString>,
120        evaluator_context: &mut EvaluatorContext,
121    ) -> Option<DynamicValue> {
122        let unwrapped_field = match field {
123            Some(f) => f.value.as_str(),
124            _ => return None,
125        };
126
127        if unwrapped_field != "country" {
128            return None;
129        }
130
131        let ip = match user.get_user_value(&Some(DynamicString::from(IP.to_string()))) {
132            Some(v) => match &v.string_value {
133                Some(s) => &s.value,
134                _ => return None,
135            },
136            None => return None,
137        };
138
139        Self::lookup(ip, evaluator_context)
140    }
141
142    fn lookup(ip_address: &str, evaluator_context: &mut EvaluatorContext) -> Option<DynamicValue> {
143        let parts: Vec<&str> = ip_address.split('.').collect();
144        if parts.len() != 4 {
145            return None;
146        }
147
148        let lock = unwrap_or_return_with!(
149            COUNTRY_LOOKUP_DATA.try_read_for(std::time::Duration::from_secs(5)),
150            || {
151                evaluator_context.result.override_reason = Some(UNINITIALIZED_REASON);
152                log_e!(TAG, "Failed to acquire read lock on country lookup");
153                None
154            }
155        );
156
157        let country_lookup_data = unwrap_or_return_with!(lock.as_ref(), || {
158            evaluator_context.result.override_reason = Some(UNINITIALIZED_REASON);
159            log_e!(TAG, "Failed to load country lookup. Did you disable CountryLookup or did not wait for country lookup to init. Check StatsigOptions configuration");
160            None
161        });
162
163        let nums: Vec<Option<i64>> = parts.iter().map(|&x| x.parse().ok()).collect();
164        if let (Some(n0), Some(n1), Some(n2), Some(n3)) = (nums[0], nums[1], nums[2], nums[3]) {
165            let ip_number = (n0 * 256_i64.pow(3)) + (n1 << 16) + (n2 << 8) + n3;
166            return Self::lookup_numeric(ip_number, country_lookup_data);
167        }
168
169        None
170    }
171
172    fn lookup_numeric(
173        ip_address: i64,
174        country_lookup_data: &CountryLookupData,
175    ) -> Option<DynamicValue> {
176        let index = Self::binary_search(ip_address, country_lookup_data);
177        let cc = country_lookup_data.country_codes[index].clone();
178        if cc == "--" {
179            return None;
180        }
181        Some(dyn_value!(cc))
182    }
183
184    fn binary_search(value: i64, country_lookup_data: &CountryLookupData) -> usize {
185        let mut min = 0;
186        let mut max = country_lookup_data.ip_ranges.len();
187
188        while min < max {
189            let mid = (min + max) >> 1;
190            if country_lookup_data.ip_ranges[mid] <= value {
191                min = mid + 1;
192            } else {
193                max = mid;
194            }
195        }
196
197        min
198    }
199}