statsig_rust/
sampling_processor.rs

1use crate::evaluation::evaluation_types::AnyEvaluation;
2use crate::global_configs::GlobalConfigs;
3use crate::hashing::HashUtil;
4use crate::hashset_with_ttl::HashSetWithTTL;
5use crate::statsig_user_internal::StatsigUserInternal;
6use crate::{StatsigErr, StatsigRuntime};
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use tokio::time::Duration;
10
/// Lifetime, in seconds, handed to the TTL set that tracks which
/// `<name>_<rule_id>` pairs have already been observed.
const TTL_IN_SECONDS: u64 = 60;

/// Rule IDs that are sampled with the SDK-config-provided
/// `special_case_sampling_rate` rather than the per-evaluation rate.
const SPECIAL_CASE_RULES: [&str; 3] = ["disabled", "default", ""];
13
14#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum SamplingStatus {
17    Logged,
18    Dropped,
19    #[default]
20    None,
21}
22
23#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
24#[serde(rename_all = "lowercase")]
25pub enum SamplingMode {
26    On,
27    Shadow,
28    #[default]
29    None,
30}
31
32#[derive(Default)]
33pub struct SamplingDecision {
34    pub should_send_exposure: bool,
35    pub sampling_rate: Option<u64>,
36    pub sampling_status: SamplingStatus,
37    pub sampling_mode: SamplingMode,
38}
39
40impl SamplingDecision {
41    pub fn new(
42        should_send_exposure: bool,
43        sampling_rate: Option<u64>,
44        sampling_status: SamplingStatus,
45        sampling_mode: SamplingMode,
46    ) -> Self {
47        Self {
48            should_send_exposure,
49            sampling_rate,
50            sampling_status,
51            sampling_mode,
52        }
53    }
54
55    pub fn force_logged() -> Self {
56        Self {
57            should_send_exposure: true,
58            sampling_rate: None,
59            sampling_status: SamplingStatus::None,
60            sampling_mode: SamplingMode::None,
61        }
62    }
63}
64
65pub struct SamplingProcessor {
66    sampling_key_set: HashSetWithTTL,
67    hashing: Arc<HashUtil>,
68    global_configs: Arc<GlobalConfigs>,
69}
70
71impl SamplingProcessor {
72    pub fn new(
73        statsig_runtime: &Arc<StatsigRuntime>,
74        hashing: Arc<HashUtil>,
75        sdk_key: &str,
76    ) -> Self {
77        let sampling_key_set =
78            HashSetWithTTL::new(statsig_runtime, Duration::from_secs(TTL_IN_SECONDS));
79
80        Self {
81            sampling_key_set,
82            hashing,
83            global_configs: GlobalConfigs::get_instance(sdk_key),
84        }
85    }
86
87    pub async fn shutdown(&self, _timeout: Duration) -> Result<(), StatsigErr> {
88        self.sampling_key_set.shutdown().await;
89        Ok(())
90    }
91
92    pub fn get_sampling_decision_and_details(
93        &self,
94        user: &StatsigUserInternal,
95        eval_result: Option<&AnyEvaluation>,
96        parameter_name_for_layer: Option<&str>,
97    ) -> SamplingDecision {
98        let eval_result = match eval_result {
99            Some(result) => result,
100            None => return SamplingDecision::force_logged(),
101        };
102
103        if self.should_skip_sampling(eval_result) {
104            return SamplingDecision::force_logged();
105        }
106
107        let base_eval_res = eval_result.get_base_result();
108        let sampling_ttl_set_key = format!("{}_{}", base_eval_res.name, base_eval_res.rule_id);
109
110        if !self
111            .sampling_key_set
112            .contains(&sampling_ttl_set_key)
113            .unwrap_or(false)
114        {
115            let _ = self.sampling_key_set.add(sampling_ttl_set_key);
116            return SamplingDecision::force_logged();
117        }
118
119        let sampling_mode = self.get_sampling_mode();
120        let sampling_exposure_key =
121            self.compute_sampling_exposure_key(eval_result, user, parameter_name_for_layer);
122
123        let (should_send_exposures, sampling_rate) =
124            self.evaluate_exposure_sending(eval_result, &sampling_exposure_key);
125
126        let sampling_log_status = match sampling_rate {
127            None => SamplingStatus::None, // No sampling rate, no status
128            Some(_) if should_send_exposures => SamplingStatus::Logged,
129            Some(_) => SamplingStatus::Dropped,
130        };
131
132        match sampling_mode {
133            SamplingMode::On => SamplingDecision::new(
134                should_send_exposures,
135                sampling_rate,
136                sampling_log_status,
137                SamplingMode::On,
138            ),
139            SamplingMode::Shadow => SamplingDecision::new(
140                true,
141                sampling_rate,
142                sampling_log_status,
143                SamplingMode::Shadow,
144            ),
145            _ => SamplingDecision::force_logged(),
146        }
147    }
148
149    // -------------------------
150    //   Utils For Generating Sampling Related Exposure Key
151    // -------------------------
152
153    fn compute_sampling_exposure_key(
154        &self,
155        eval_result: &AnyEvaluation,
156        user: &StatsigUserInternal,
157        parameter_name_for_layer: Option<&str>,
158    ) -> String {
159        let base_eval_res = eval_result.get_base_result();
160
161        match eval_result {
162            AnyEvaluation::Layer(eval) => self.compute_sampling_key_for_layer(
163                &base_eval_res.name,
164                eval.allocated_experiment_name.as_deref().unwrap_or("null"),
165                parameter_name_for_layer.unwrap_or("null"),
166                &base_eval_res.rule_id,
167                user,
168            ),
169            _ => self.compute_sampling_key_for_gate_or_config(
170                &base_eval_res.name,
171                &base_eval_res.rule_id,
172                &eval_result.get_gate_bool_value(),
173                user,
174            ),
175        }
176    }
177
178    /// compute sampling key for gate / experiment / dynamic config
179    fn compute_sampling_key_for_gate_or_config(
180        &self,
181        name: &str,
182        rule_id: &str,
183        value: &bool,
184        user: &StatsigUserInternal,
185    ) -> String {
186        let user_key = self.compute_user_key(user);
187        format!("n:{name};u:{user_key};r:{rule_id};v:{value}")
188    }
189
190    /// compute sampling key for layers
191    fn compute_sampling_key_for_layer(
192        &self,
193        layer_name: &str,
194        experiment_name: &str,
195        parameter_name: &str,
196        rule_id: &str,
197        user: &StatsigUserInternal,
198    ) -> String {
199        let user_key = self.compute_user_key(user);
200        format!("n:{layer_name};e:{experiment_name};p:{parameter_name};u:{user_key};r:{rule_id}")
201    }
202
203    fn compute_user_key(&self, user: &StatsigUserInternal) -> String {
204        let user_data = &user.user_data;
205
206        let mut user_key = format!(
207            "u:{};",
208            user_data
209                .user_id
210                .as_ref()
211                .and_then(|id| id.string_value.as_deref())
212                .unwrap_or("")
213        );
214
215        if let Some(custom_ids) = user_data.custom_ids.as_ref() {
216            for (key, val) in custom_ids {
217                if let Some(string_value) = &val.string_value {
218                    user_key.push_str(&format!("{key}:{string_value};"));
219                }
220            }
221        };
222
223        user_key
224    }
225
226    // -------------------------
227    //   Other Helper Functions
228    // -------------------------
229
230    /// Returns a tuple:
231    /// - `bool`: Whether exposures should be sent.
232    /// - `Option<u64>`: The sampling rate used for the decision (if applicable).
233    fn evaluate_exposure_sending(
234        &self,
235        eval_result: &AnyEvaluation,
236        sampling_exposure_key: &String,
237    ) -> (bool, Option<u64>) {
238        let eval_base_res = eval_result.get_base_result();
239        let special_case_sampling_rate = self.get_special_case_sampling_rate();
240
241        if SPECIAL_CASE_RULES.contains(&eval_base_res.rule_id.as_str())
242            && special_case_sampling_rate.is_some()
243        {
244            if let Some(special_rate) = special_case_sampling_rate {
245                let should_send_exposures =
246                    self.is_hash_in_sampling_rate(sampling_exposure_key, special_rate);
247                return (should_send_exposures, Some(special_rate));
248            }
249        }
250
251        if let Some(rate) = eval_base_res
252            .sampling_info
253            .as_ref()
254            .and_then(|info| info.sampling_rate)
255        {
256            let should_send_exposures = self.is_hash_in_sampling_rate(sampling_exposure_key, rate);
257
258            return (should_send_exposures, Some(rate));
259        }
260
261        (true, None) // default to true, always send exposures, do NOT sample
262    }
263
264    fn should_skip_sampling(&self, eval_result: &AnyEvaluation) -> bool {
265        let sampling_mode = self.get_sampling_mode();
266
267        if matches!(sampling_mode, SamplingMode::None) {
268            return true;
269        }
270
271        let sampling_info = eval_result.get_base_result().sampling_info.as_ref();
272
273        if sampling_info
274            .and_then(|info| info.forward_all_exposures)
275            .unwrap_or(false)
276        {
277            return true;
278        }
279
280        if sampling_info
281            .and_then(|info| info.has_seen_analytical_gates)
282            .unwrap_or(false)
283        {
284            return true;
285        }
286
287        false
288    }
289
290    fn is_hash_in_sampling_rate(&self, key: &String, sampling_rate: u64) -> bool {
291        let hash_value = self.hashing.sha256_to_u64(key);
292        hash_value % sampling_rate == 0
293    }
294
295    fn get_sampling_mode(&self) -> SamplingMode {
296        self.global_configs
297            .get_sdk_config_value("sampling_mode")
298            .and_then(|mode| mode.string_value)
299            .as_deref()
300            .map_or(SamplingMode::None, |mode_str| match mode_str {
301                "on" => SamplingMode::On,
302                "shadow" => SamplingMode::Shadow,
303                _ => SamplingMode::None,
304            })
305    }
306
307    fn get_special_case_sampling_rate(&self) -> Option<u64> {
308        self.global_configs
309            .get_sdk_config_value("special_case_sampling_rate")
310            .and_then(|v| v.float_value)
311            .map(|rate| rate as u64)
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::evaluation_types::{BaseEvaluation, GateEvaluation};
    use crate::{DynamicValue, StatsigUser};
    use std::collections::HashMap;
    use std::sync::LazyLock;

    /// Shared gate fixture: `publish_to_all` evaluated to `false` under rule
    /// `rule_id`, with default sampling info.
    static GATE: LazyLock<GateEvaluation> = LazyLock::new(|| {
        let base = BaseEvaluation {
            name: String::from("publish_to_all"),
            rule_id: String::from("rule_id"),
            secondary_exposures: Vec::new(),
            sampling_info: Default::default(),
        };

        GateEvaluation {
            base,
            id_type: String::new(),
            value: false,
        }
    });

    /// Builds a user with two custom IDs (`k1`/`v1`, `k2`/`v2`) and a
    /// `tier = "development"` environment entry.
    fn create_mock_user() -> StatsigUserInternal {
        let custom_ids: HashMap<String, String> = [("k1", "v1"), ("k2", "v2")]
            .into_iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        let tier = DynamicValue {
            string_value: Some("development".to_string()),
            ..Default::default()
        };
        let mut environment = HashMap::new();
        environment.insert("tier".to_string(), tier);

        StatsigUserInternal {
            user_data: StatsigUser::with_custom_ids(custom_ids),
            statsig_environment: Some(environment),
        }
    }

    /// Wraps the shared `GATE` fixture as a feature-gate evaluation.
    fn create_mock_evaluation_result() -> AnyEvaluation<'static> {
        AnyEvaluation::FeatureGate(&GATE)
    }
}