statsig_rust/
sampling_processor.rs

1use crate::evaluation::evaluation_types::AnyEvaluation;
2use crate::global_configs::GlobalConfigs;
3use crate::hashing::HashUtil;
4use crate::hashset_with_ttl::HashSetWithTTL;
5use crate::statsig_user_internal::StatsigUserInternal;
6use crate::{DynamicValue, StatsigErr, StatsigRuntime};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::time::Duration;
11
/// Lifetime, in seconds, given to the TTL-backed set of already-seen sampling keys.
const TTL_IN_SECONDS: u64 = 60;

/// Rule IDs that are sampled with the `special_case_sampling_rate` SDK config
/// (when that config is present) instead of the evaluation's own sampling rate.
const SPECIAL_CASE_RULES: [&str; 3] = ["disabled", "default", ""];
14
15#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
16#[serde(rename_all = "lowercase")]
17pub enum SamplingStatus {
18    Logged,
19    Dropped,
20    #[default]
21    None,
22}
23
24#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
25#[serde(rename_all = "lowercase")]
26pub enum SamplingMode {
27    On,
28    Shadow,
29    #[default]
30    None,
31}
32
33#[derive(Default)]
34pub struct SamplingDecision {
35    pub should_send_exposure: bool,
36    pub sampling_rate: Option<u64>,
37    pub sampling_status: SamplingStatus,
38    pub sampling_mode: SamplingMode,
39}
40
41impl SamplingDecision {
42    pub fn new(
43        should_send_exposure: bool,
44        sampling_rate: Option<u64>,
45        sampling_status: SamplingStatus,
46        sampling_mode: SamplingMode,
47    ) -> Self {
48        Self {
49            should_send_exposure,
50            sampling_rate,
51            sampling_status,
52            sampling_mode,
53        }
54    }
55
56    pub fn force_logged() -> Self {
57        Self {
58            should_send_exposure: true,
59            sampling_rate: None,
60            sampling_status: SamplingStatus::None,
61            sampling_mode: SamplingMode::None,
62        }
63    }
64}
65
66pub struct SamplingProcessor {
67    sampling_key_set: HashSetWithTTL,
68    hashing: Arc<HashUtil>,
69    global_configs: Arc<GlobalConfigs>,
70}
71
72impl SamplingProcessor {
73    pub fn new(
74        statsig_runtime: &Arc<StatsigRuntime>,
75        hashing: Arc<HashUtil>,
76        sdk_key: &str,
77    ) -> Self {
78        let sampling_key_set =
79            HashSetWithTTL::new(statsig_runtime, Duration::from_secs(TTL_IN_SECONDS));
80
81        Self {
82            sampling_key_set,
83            hashing,
84            global_configs: GlobalConfigs::get_instance(sdk_key),
85        }
86    }
87
88    pub async fn shutdown(&self, _timeout: Duration) -> Result<(), StatsigErr> {
89        self.sampling_key_set.shutdown().await;
90        Ok(())
91    }
92
93    pub fn get_sampling_decision_and_details(
94        &self,
95        user: &StatsigUserInternal,
96        eval_result: Option<&AnyEvaluation>,
97        parameter_name_for_layer: Option<&str>,
98    ) -> SamplingDecision {
99        let eval_result = match eval_result {
100            Some(result) => result,
101            None => return SamplingDecision::force_logged(),
102        };
103
104        if self.should_skip_sampling(eval_result, &user.statsig_environment) {
105            return SamplingDecision::force_logged();
106        }
107
108        let base_eval_res = eval_result.get_base_result();
109        let sampling_ttl_set_key = format!("{}_{}", base_eval_res.name, base_eval_res.rule_id);
110
111        if !self
112            .sampling_key_set
113            .contains(&sampling_ttl_set_key)
114            .unwrap_or(false)
115        {
116            let _ = self.sampling_key_set.add(sampling_ttl_set_key);
117            return SamplingDecision::force_logged();
118        }
119
120        let sampling_mode = self.get_sampling_mode();
121        let sampling_exposure_key =
122            self.compute_sampling_exposure_key(eval_result, user, parameter_name_for_layer);
123
124        let (should_send_exposures, sampling_rate) =
125            self.evaluate_exposure_sending(eval_result, &sampling_exposure_key);
126
127        let sampling_log_status = match sampling_rate {
128            None => SamplingStatus::None, // No sampling rate, no status
129            Some(_) if should_send_exposures => SamplingStatus::Logged,
130            Some(_) => SamplingStatus::Dropped,
131        };
132
133        match sampling_mode {
134            SamplingMode::On => SamplingDecision::new(
135                should_send_exposures,
136                sampling_rate,
137                sampling_log_status,
138                SamplingMode::On,
139            ),
140            SamplingMode::Shadow => SamplingDecision::new(
141                true,
142                sampling_rate,
143                sampling_log_status,
144                SamplingMode::Shadow,
145            ),
146            _ => SamplingDecision::force_logged(),
147        }
148    }
149
150    // -------------------------
151    //   Utils For Generating Sampling Related Exposure Key
152    // -------------------------
153
154    fn compute_sampling_exposure_key(
155        &self,
156        eval_result: &AnyEvaluation,
157        user: &StatsigUserInternal,
158        parameter_name_for_layer: Option<&str>,
159    ) -> String {
160        let base_eval_res = eval_result.get_base_result();
161
162        match eval_result {
163            AnyEvaluation::Layer(eval) => self.compute_sampling_key_for_layer(
164                &base_eval_res.name,
165                eval.allocated_experiment_name.as_deref().unwrap_or("null"),
166                parameter_name_for_layer.unwrap_or("null"),
167                &base_eval_res.rule_id,
168                user,
169            ),
170            _ => self.compute_sampling_key_for_gate_or_config(
171                &base_eval_res.name,
172                &base_eval_res.rule_id,
173                &eval_result.get_gate_bool_value(),
174                user,
175            ),
176        }
177    }
178
179    /// compute sampling key for gate / experiment / dynamic config
180    fn compute_sampling_key_for_gate_or_config(
181        &self,
182        name: &str,
183        rule_id: &str,
184        value: &bool,
185        user: &StatsigUserInternal,
186    ) -> String {
187        let user_key = self.compute_user_key(user);
188        format!("n:{name};u:{user_key};r:{rule_id};v:{value}")
189    }
190
191    /// compute sampling key for layers
192    fn compute_sampling_key_for_layer(
193        &self,
194        layer_name: &str,
195        experiment_name: &str,
196        parameter_name: &str,
197        rule_id: &str,
198        user: &StatsigUserInternal,
199    ) -> String {
200        let user_key = self.compute_user_key(user);
201        format!("n:{layer_name};e:{experiment_name};p:{parameter_name};u:{user_key};r:{rule_id}")
202    }
203
204    fn compute_user_key(&self, user: &StatsigUserInternal) -> String {
205        let user_data = &user.user_data;
206
207        let mut user_key = format!(
208            "u:{};",
209            user_data
210                .user_id
211                .as_ref()
212                .and_then(|id| id.string_value.as_deref())
213                .unwrap_or("")
214        );
215
216        if let Some(custom_ids) = user_data.custom_ids.as_ref() {
217            for (key, val) in custom_ids {
218                if let Some(string_value) = &val.string_value {
219                    user_key.push_str(&format!("{key}:{string_value};"));
220                }
221            }
222        };
223
224        user_key
225    }
226
227    // -------------------------
228    //   Other Helper Functions
229    // -------------------------
230
231    /// Returns a tuple:
232    /// - `bool`: Whether exposures should be sent.
233    /// - `Option<u64>`: The sampling rate used for the decision (if applicable).
234    fn evaluate_exposure_sending(
235        &self,
236        eval_result: &AnyEvaluation,
237        sampling_exposure_key: &String,
238    ) -> (bool, Option<u64>) {
239        let eval_base_res = eval_result.get_base_result();
240        let special_case_sampling_rate = self.get_special_case_sampling_rate();
241
242        if SPECIAL_CASE_RULES.contains(&eval_base_res.rule_id.as_str())
243            && special_case_sampling_rate.is_some()
244        {
245            if let Some(special_rate) = special_case_sampling_rate {
246                let should_send_exposures =
247                    self.is_hash_in_sampling_rate(sampling_exposure_key, special_rate);
248                return (should_send_exposures, Some(special_rate));
249            }
250        }
251
252        if let Some(rate) = eval_base_res
253            .sampling_info
254            .as_ref()
255            .and_then(|info| info.sampling_rate)
256        {
257            let should_send_exposures = self.is_hash_in_sampling_rate(sampling_exposure_key, rate);
258
259            return (should_send_exposures, Some(rate));
260        }
261
262        (true, None) // default to true, always send exposures, do NOT sample
263    }
264
265    fn should_skip_sampling(
266        &self,
267        eval_result: &AnyEvaluation,
268        env: &Option<HashMap<String, DynamicValue>>,
269    ) -> bool {
270        let sampling_mode = self.get_sampling_mode();
271
272        if matches!(sampling_mode, SamplingMode::None) {
273            return true;
274        }
275
276        let sampling_info = eval_result.get_base_result().sampling_info.as_ref();
277
278        if sampling_info
279            .and_then(|info| info.forward_all_exposures)
280            .unwrap_or(false)
281        {
282            return true;
283        }
284
285        if sampling_info
286            .and_then(|info| info.has_seen_analytical_gates)
287            .unwrap_or(false)
288        {
289            return true;
290        }
291
292        // skip sampling if env is not in production
293        if env
294            .as_ref()
295            .and_then(|e| e.get("tier"))
296            .and_then(|tier| tier.string_value.as_deref())
297            != Some("production")
298        {
299            return true;
300        }
301
302        let rule_id = &eval_result.get_base_result().rule_id;
303        if rule_id.ends_with(":override") || rule_id.ends_with(":id_override") {
304            return true;
305        }
306
307        false
308    }
309
310    fn is_hash_in_sampling_rate(&self, key: &String, sampling_rate: u64) -> bool {
311        let hash_value = self.hashing.sha256_to_u64(key);
312        hash_value % sampling_rate == 0
313    }
314
315    fn get_sampling_mode(&self) -> SamplingMode {
316        self.global_configs
317            .get_sdk_config_value("sampling_mode")
318            .and_then(|mode| mode.string_value)
319            .as_deref()
320            .map_or(SamplingMode::None, |mode_str| match mode_str {
321                "on" => SamplingMode::On,
322                "shadow" => SamplingMode::Shadow,
323                _ => SamplingMode::None,
324            })
325    }
326
327    fn get_special_case_sampling_rate(&self) -> Option<u64> {
328        self.global_configs
329            .get_sdk_config_value("special_case_sampling_rate")
330            .and_then(|v| v.float_value)
331            .map(|rate| rate as u64)
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::evaluation_types::{BaseEvaluation, GateEvaluation};
    use crate::{SpecStore, SpecsSource, SpecsUpdate, StatsigUser};
    use serde_json::Value;
    use std::fs;
    use std::sync::LazyLock;

    /// Gate evaluation shared by the tests: a failing gate with a plain rule
    /// ID and default sampling info.
    static GATE: LazyLock<GateEvaluation> = LazyLock::new(|| GateEvaluation {
        base: BaseEvaluation {
            name: "publish_to_all".to_string(),
            rule_id: "rule_id".to_string(),
            secondary_exposures: vec![],
            sampling_info: Default::default(),
        },
        id_type: String::new(),
        value: false,
    });

    /// Builds an environment map containing only a `tier` entry.
    fn environment_with_tier(tier: &str) -> Option<HashMap<String, DynamicValue>> {
        let tier_value = DynamicValue {
            string_value: Some(tier.to_string()),
            ..Default::default()
        };

        Some(HashMap::from([("tier".to_string(), tier_value)]))
    }

    /// User with two custom IDs, placed in the `development` tier.
    fn create_mock_user() -> StatsigUserInternal {
        let custom_ids = HashMap::from([
            ("k1".to_string(), "v1".to_string()),
            ("k2".to_string(), "v2".to_string()),
        ]);

        StatsigUserInternal {
            user_data: StatsigUser::with_custom_ids(custom_ids),
            statsig_environment: environment_with_tier("development"),
        }
    }

    fn create_mock_evaluation_result() -> AnyEvaluation<'static> {
        AnyEvaluation::FeatureGate(&GATE)
    }

    #[test]
    fn test_should_skip_sampling() {
        // Load the fixture and push it into a SpecStore so the SDK configs are populated.
        let file_content = fs::read_to_string("tests/data/dcs_with_sdk_configs.json")
            .expect("Unable to read file");
        let json_data: Value = serde_json::from_str(&file_content).expect("Unable to parse JSON");

        let spec_store = SpecStore::default();
        spec_store
            .set_values(SpecsUpdate {
                data: json_data.to_string(),
                source: SpecsSource::Network,
                received_at: 2000,
            })
            .expect("Set Specstore failed");

        // Initialize the sampling processor.
        let runtime = StatsigRuntime::get_runtime();
        let processor = SamplingProcessor::new(&runtime, Arc::new(HashUtil::new()), "");

        let mut test_user = create_mock_user();
        let mock_evaluation_res = create_mock_evaluation_result();

        // Sampling is skipped outside of production.
        assert!(
            processor.should_skip_sampling(&mock_evaluation_res, &test_user.statsig_environment)
        );

        // Sampling is not skipped in production.
        test_user.statsig_environment = environment_with_tier("production");
        assert!(
            !processor.should_skip_sampling(&mock_evaluation_res, &test_user.statsig_environment)
        );
    }
}