// statsig_rust/sampling_processor.rs

1use crate::evaluation::evaluation_types::AnyEvaluation;
2use crate::global_configs::GlobalConfigs;
3use crate::hashing::HashUtil;
4use crate::hashset_with_ttl::HashSetWithTTL;
5use crate::statsig_user_internal::StatsigUserInternal;
6use crate::{DynamicValue, StatsigErr, StatsigRuntime};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::time::Duration;
11
/// TTL passed to `HashSetWithTTL::new` for the per-(name, rule) sampling key set.
const TTL_IN_SECONDS: u64 = 60;

/// Rule ids for which the SDK-wide `special_case_sampling_rate` (when configured) is used
/// instead of the evaluation's own sampling rate.
const SPECIAL_CASE_RULES: [&'static str; 3] = [
    "disabled",
    "default",
    "",
];
14
15#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
16#[serde(rename_all = "lowercase")]
17pub enum SamplingStatus {
18    Logged,
19    Dropped,
20    #[default]
21    None,
22}
23
24#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
25#[serde(rename_all = "lowercase")]
26pub enum SamplingMode {
27    On,
28    Shadow,
29    #[default]
30    None,
31}
32
33#[derive(Default)]
34pub struct SamplingDecision {
35    pub should_send_exposure: bool,
36    pub sampling_rate: Option<u64>,
37    pub sampling_status: SamplingStatus,
38    pub sampling_mode: SamplingMode,
39}
40
41impl SamplingDecision {
42    pub fn new(
43        should_send_exposure: bool,
44        sampling_rate: Option<u64>,
45        sampling_status: SamplingStatus,
46        sampling_mode: SamplingMode,
47    ) -> Self {
48        Self {
49            should_send_exposure,
50            sampling_rate,
51            sampling_status,
52            sampling_mode,
53        }
54    }
55
56    pub fn force_logged() -> Self {
57        Self {
58            should_send_exposure: true,
59            sampling_rate: None,
60            sampling_status: SamplingStatus::None,
61            sampling_mode: SamplingMode::None,
62        }
63    }
64}
65
66pub struct SamplingProcessor {
67    sampling_key_set: HashSetWithTTL,
68    hashing: Arc<HashUtil>,
69    global_configs: Arc<GlobalConfigs>,
70}
71
72impl SamplingProcessor {
73    pub fn new(
74        statsig_runtime: &Arc<StatsigRuntime>,
75        hashing: Arc<HashUtil>,
76        sdk_key: &str,
77    ) -> Self {
78        let sampling_key_set =
79            HashSetWithTTL::new(statsig_runtime, Duration::from_secs(TTL_IN_SECONDS));
80
81        Self {
82            sampling_key_set,
83            hashing,
84            global_configs: GlobalConfigs::get_instance(sdk_key),
85        }
86    }
87
88    pub async fn shutdown(&self, _timeout: Duration) -> Result<(), StatsigErr> {
89        self.sampling_key_set.shutdown().await;
90        Ok(())
91    }
92
93    pub fn get_sampling_decision_and_details(
94        &self,
95        user: &StatsigUserInternal,
96        eval_result: Option<&AnyEvaluation>,
97        parameter_name_for_layer: Option<&str>,
98    ) -> SamplingDecision {
99        let eval_result = match eval_result {
100            Some(result) => result,
101            None => return SamplingDecision::force_logged(),
102        };
103
104        if self.should_skip_sampling(eval_result, &user.statsig_environment) {
105            return SamplingDecision::force_logged();
106        }
107
108        let base_eval_res = eval_result.get_base_result();
109        let sampling_ttl_set_key = format!("{}_{}", base_eval_res.name, base_eval_res.rule_id);
110
111        if !self
112            .sampling_key_set
113            .contains(&sampling_ttl_set_key)
114            .unwrap_or(false)
115        {
116            let _ = self.sampling_key_set.add(sampling_ttl_set_key);
117            return SamplingDecision::force_logged();
118        }
119
120        let sampling_mode = self.get_sampling_mode();
121        let sampling_exposure_key =
122            self.compute_sampling_exposure_key(eval_result, user, parameter_name_for_layer);
123
124        let (should_send_exposures, sampling_rate) =
125            self.evaluate_exposure_sending(eval_result, &sampling_exposure_key);
126
127        let sampling_log_status = match sampling_rate {
128            None => SamplingStatus::None, // No sampling rate, no status
129            Some(_) if should_send_exposures => SamplingStatus::Logged,
130            Some(_) => SamplingStatus::Dropped,
131        };
132
133        match sampling_mode {
134            SamplingMode::On => SamplingDecision::new(
135                should_send_exposures,
136                sampling_rate,
137                sampling_log_status,
138                SamplingMode::On,
139            ),
140            SamplingMode::Shadow => SamplingDecision::new(
141                true,
142                sampling_rate,
143                sampling_log_status,
144                SamplingMode::Shadow,
145            ),
146            _ => SamplingDecision::force_logged(),
147        }
148    }
149
150    // -------------------------
151    //   Utils For Generating Sampling Related Exposure Key
152    // -------------------------
153
154    fn compute_sampling_exposure_key(
155        &self,
156        eval_result: &AnyEvaluation,
157        user: &StatsigUserInternal,
158        parameter_name_for_layer: Option<&str>,
159    ) -> String {
160        let base_eval_res = eval_result.get_base_result();
161
162        match eval_result {
163            AnyEvaluation::Layer(eval) => self.compute_sampling_key_for_layer(
164                &base_eval_res.name,
165                eval.allocated_experiment_name.as_deref().unwrap_or("null"),
166                parameter_name_for_layer.unwrap_or("null"),
167                &base_eval_res.rule_id,
168                user,
169            ),
170            _ => self.compute_sampling_key_for_gate_or_config(
171                &base_eval_res.name,
172                &base_eval_res.rule_id,
173                &eval_result.get_gate_bool_value(),
174                user,
175            ),
176        }
177    }
178
179    /// compute sampling key for gate / experiment / dynamic config
180    fn compute_sampling_key_for_gate_or_config(
181        &self,
182        name: &str,
183        rule_id: &str,
184        value: &bool,
185        user: &StatsigUserInternal,
186    ) -> String {
187        let user_key = self.compute_user_key(user);
188        format!("n:{name};u:{user_key};r:{rule_id};v:{value}")
189    }
190
191    /// compute sampling key for layers
192    fn compute_sampling_key_for_layer(
193        &self,
194        layer_name: &str,
195        experiment_name: &str,
196        parameter_name: &str,
197        rule_id: &str,
198        user: &StatsigUserInternal,
199    ) -> String {
200        let user_key = self.compute_user_key(user);
201        format!("n:{layer_name};e:{experiment_name};p:{parameter_name};u:{user_key};r:{rule_id}")
202    }
203
204    fn compute_user_key(&self, user: &StatsigUserInternal) -> String {
205        let user_data = &user.user_data;
206
207        let mut user_key = format!(
208            "u:{};",
209            user_data
210                .user_id
211                .as_ref()
212                .and_then(|id| id.string_value.as_deref())
213                .unwrap_or("")
214        );
215
216        if let Some(custom_ids) = user_data.custom_ids.as_ref() {
217            for (key, val) in custom_ids {
218                if let Some(string_value) = &val.string_value {
219                    user_key.push_str(&format!("{key}:{string_value};"));
220                }
221            }
222        };
223
224        user_key
225    }
226
227    // -------------------------
228    //   Other Helper Functions
229    // -------------------------
230
231    /// Returns a tuple:
232    /// - `bool`: Whether exposures should be sent.
233    /// - `Option<u64>`: The sampling rate used for the decision (if applicable).
234    fn evaluate_exposure_sending(
235        &self,
236        eval_result: &AnyEvaluation,
237        sampling_exposure_key: &String,
238    ) -> (bool, Option<u64>) {
239        let eval_base_res = eval_result.get_base_result();
240        let special_case_sampling_rate = self.get_special_case_sampling_rate();
241
242        if SPECIAL_CASE_RULES.contains(&eval_base_res.rule_id.as_str())
243            && special_case_sampling_rate.is_some()
244        {
245            if let Some(special_rate) = special_case_sampling_rate {
246                let should_send_exposures =
247                    self.is_hash_in_sampling_rate(sampling_exposure_key, special_rate);
248                return (should_send_exposures, Some(special_rate));
249            }
250        }
251
252        if let Some(rate) = eval_base_res.sampling_rate {
253            let should_send_exposures = self.is_hash_in_sampling_rate(sampling_exposure_key, rate);
254
255            return (should_send_exposures, Some(rate));
256        }
257
258        (true, None) // default to true, always send exposures, do NOT sample
259    }
260
261    fn should_skip_sampling(
262        &self,
263        eval_result: &AnyEvaluation,
264        env: &Option<HashMap<String, DynamicValue>>,
265    ) -> bool {
266        let sampling_mode = self.get_sampling_mode();
267
268        if matches!(sampling_mode, SamplingMode::None) {
269            return true;
270        }
271
272        if eval_result
273            .get_base_result()
274            .forward_all_exposures
275            .unwrap_or(false)
276        {
277            return true;
278        }
279
280        // skip sampling if env is not in production
281        if env
282            .as_ref()
283            .and_then(|e| e.get("tier"))
284            .and_then(|tier| tier.string_value.as_deref())
285            != Some("production")
286        {
287            return true;
288        }
289
290        let rule_id = &eval_result.get_base_result().rule_id;
291        if rule_id.ends_with(":override") || rule_id.ends_with(":id_override") {
292            return true;
293        }
294
295        false
296    }
297
298    fn is_hash_in_sampling_rate(&self, key: &String, sampling_rate: u64) -> bool {
299        let hash_value = self.hashing.sha256_to_u64(key);
300        hash_value % sampling_rate == 0
301    }
302
303    fn get_sampling_mode(&self) -> SamplingMode {
304        self.global_configs
305            .get_sdk_config_value("sampling_mode")
306            .and_then(|mode| mode.string_value)
307            .as_deref()
308            .map_or(SamplingMode::None, |mode_str| match mode_str {
309                "on" => SamplingMode::On,
310                "shadow" => SamplingMode::Shadow,
311                _ => SamplingMode::None,
312            })
313    }
314
315    fn get_special_case_sampling_rate(&self) -> Option<u64> {
316        self.global_configs
317            .get_sdk_config_value("special_case_sampling_rate")
318            .and_then(|v| v.float_value)
319            .map(|rate| rate as u64)
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::evaluation_types::{BaseEvaluation, GateEvaluation};
    use crate::{SpecStore, SpecsSource, SpecsUpdate, StatsigUser};
    use serde_json::Value;
    use std::fs;
    use std::sync::LazyLock;

    /// Feature-gate evaluation with no sampling rate and `forward_all_exposures` disabled.
    static GATE: LazyLock<GateEvaluation> = LazyLock::new(|| GateEvaluation {
        base: BaseEvaluation {
            name: "publish_to_all".to_string(),
            rule_id: "rule_id".to_string(),
            secondary_exposures: vec![],
            sampling_rate: None,
            forward_all_exposures: Some(false),
        },
        id_type: String::new(),
        value: false,
    });

    /// Builds a `statsig_environment` map that contains only the given `tier`.
    fn tier_environment(tier: &str) -> Option<HashMap<String, DynamicValue>> {
        let tier_value = DynamicValue {
            string_value: Some(tier.to_string()),
            ..Default::default()
        };
        Some(HashMap::from([("tier".to_string(), tier_value)]))
    }

    /// A user with two custom ids, in the "development" tier.
    fn create_mock_user() -> StatsigUserInternal {
        let custom_ids = HashMap::from([
            ("k1".to_string(), "v1".to_string()),
            ("k2".to_string(), "v2".to_string()),
        ]);

        StatsigUserInternal {
            user_data: StatsigUser::with_custom_ids(custom_ids),
            statsig_environment: tier_environment("development"),
        }
    }

    fn create_mock_evaluation_result() -> AnyEvaluation<'static> {
        AnyEvaluation::FeatureGate(&GATE)
    }

    #[test]
    fn test_should_skip_sampling() {
        let file_content = fs::read_to_string("tests/data/dcs_with_sdk_configs.json")
            .expect("Unable to read file");
        let json_data: Value = serde_json::from_str(&file_content).expect("Unable to parse JSON");

        // Load the payload carrying the sdk configs into a spec store.
        let spec_store = SpecStore::default();
        spec_store
            .set_values(SpecsUpdate {
                data: json_data.to_string(),
                source: SpecsSource::Network,
                received_at: 2000,
            })
            .expect("Set Specstore failed");

        let runtime = StatsigRuntime::get_runtime();
        let processor = SamplingProcessor::new(&runtime, Arc::new(HashUtil::new()), "");

        let mut test_user = create_mock_user();
        let evaluation = create_mock_evaluation_result();

        // Non-production tier: sampling is skipped.
        assert!(processor.should_skip_sampling(&evaluation, &test_user.statsig_environment));

        // Production tier: sampling is not skipped.
        test_user.statsig_environment = tier_environment("production");
        assert!(!processor.should_skip_sampling(&evaluation, &test_user.statsig_environment));
    }
}