1use crate::evaluation::evaluation_types::AnyEvaluation;
2use crate::global_configs::GlobalConfigs;
3use crate::hashing::HashUtil;
4use crate::hashset_with_ttl::HashSetWithTTL;
5use crate::statsig_user_internal::StatsigUserInternal;
6use crate::{DynamicValue, StatsigErr, StatsigRuntime};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::time::Duration;
11
/// Rule IDs for which the SDK-config `special_case_sampling_rate`, when set,
/// is used as the sampling rate.
const SPECIAL_CASE_RULES: [&str; 3] = ["disabled", "default", ""];
/// Duration, in seconds, handed to the `HashSetWithTTL` that tracks sampling keys.
const TTL_IN_SECONDS: u64 = 60;
14
/// Outcome of applying a sampling rate to a single exposure.
///
/// Serialized with lowercase variant names.
#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SamplingStatus {
 /// A sampling rate applied and the exposure was selected by it.
 Logged,
 /// A sampling rate applied and the exposure was not selected by it.
 Dropped,
 /// No sampling rate applied (the default).
 #[default]
 None,
}
23
/// Sampling mode read from the `sampling_mode` SDK config value.
///
/// Serialized with lowercase variant names.
#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SamplingMode {
 /// The sampling outcome decides whether an exposure is sent.
 On,
 /// The sampling outcome is computed and reported, but the exposure is always sent.
 Shadow,
 /// Sampling is not applied (the default, and the fallback for unknown config values).
 #[default]
 None,
}
32
/// Result of the sampling check for one exposure.
#[derive(Default)]
pub struct SamplingDecision {
 /// Whether the exposure event should be sent.
 pub should_send_exposure: bool,
 /// Sampling rate that was applied, or `None` when no rate applied.
 pub sampling_rate: Option<u64>,
 /// Whether the applied rate selected the exposure.
 pub sampling_status: SamplingStatus,
 /// Sampling mode recorded alongside the decision.
 pub sampling_mode: SamplingMode,
}
40
41impl SamplingDecision {
42 pub fn new(
43 should_send_exposure: bool,
44 sampling_rate: Option<u64>,
45 sampling_status: SamplingStatus,
46 sampling_mode: SamplingMode,
47 ) -> Self {
48 Self {
49 should_send_exposure,
50 sampling_rate,
51 sampling_status,
52 sampling_mode,
53 }
54 }
55
56 pub fn force_logged() -> Self {
57 Self {
58 should_send_exposure: true,
59 sampling_rate: None,
60 sampling_status: SamplingStatus::None,
61 sampling_mode: SamplingMode::None,
62 }
63 }
64}
65
/// Decides, per evaluation, whether an exposure should be sent or sampled out.
pub struct SamplingProcessor {
 /// Keys (`<name>_<rule_id>`) of evaluations that have already been seen.
 sampling_key_set: HashSetWithTTL,
 /// Hasher used to map a sampling exposure key to a `u64`.
 hashing: Arc<HashUtil>,
 /// Source of the `sampling_mode` and `special_case_sampling_rate` SDK config values.
 global_configs: Arc<GlobalConfigs>,
}
71
72impl SamplingProcessor {
73 pub fn new(
74 statsig_runtime: &Arc<StatsigRuntime>,
75 hashing: Arc<HashUtil>,
76 sdk_key: &str,
77 ) -> Self {
78 let sampling_key_set =
79 HashSetWithTTL::new(statsig_runtime, Duration::from_secs(TTL_IN_SECONDS));
80
81 Self {
82 sampling_key_set,
83 hashing,
84 global_configs: GlobalConfigs::get_instance(sdk_key),
85 }
86 }
87
88 pub async fn shutdown(&self, _timeout: Duration) -> Result<(), StatsigErr> {
89 self.sampling_key_set.shutdown().await;
90 Ok(())
91 }
92
93 pub fn get_sampling_decision_and_details(
94 &self,
95 user: &StatsigUserInternal,
96 eval_result: Option<&AnyEvaluation>,
97 parameter_name_for_layer: Option<&str>,
98 ) -> SamplingDecision {
99 let eval_result = match eval_result {
100 Some(result) => result,
101 None => return SamplingDecision::force_logged(),
102 };
103
104 if self.should_skip_sampling(eval_result, &user.statsig_environment) {
105 return SamplingDecision::force_logged();
106 }
107
108 let base_eval_res = eval_result.get_base_result();
109 let sampling_ttl_set_key = format!("{}_{}", base_eval_res.name, base_eval_res.rule_id);
110
111 if !self
112 .sampling_key_set
113 .contains(&sampling_ttl_set_key)
114 .unwrap_or(false)
115 {
116 let _ = self.sampling_key_set.add(sampling_ttl_set_key);
117 return SamplingDecision::force_logged();
118 }
119
120 let sampling_mode = self.get_sampling_mode();
121 let sampling_exposure_key =
122 self.compute_sampling_exposure_key(eval_result, user, parameter_name_for_layer);
123
124 let (should_send_exposures, sampling_rate) =
125 self.evaluate_exposure_sending(eval_result, &sampling_exposure_key);
126
127 let sampling_log_status = match sampling_rate {
128 None => SamplingStatus::None, Some(_) if should_send_exposures => SamplingStatus::Logged,
130 Some(_) => SamplingStatus::Dropped,
131 };
132
133 match sampling_mode {
134 SamplingMode::On => SamplingDecision::new(
135 should_send_exposures,
136 sampling_rate,
137 sampling_log_status,
138 SamplingMode::On,
139 ),
140 SamplingMode::Shadow => SamplingDecision::new(
141 true,
142 sampling_rate,
143 sampling_log_status,
144 SamplingMode::Shadow,
145 ),
146 _ => SamplingDecision::force_logged(),
147 }
148 }
149
150 fn compute_sampling_exposure_key(
155 &self,
156 eval_result: &AnyEvaluation,
157 user: &StatsigUserInternal,
158 parameter_name_for_layer: Option<&str>,
159 ) -> String {
160 let base_eval_res = eval_result.get_base_result();
161
162 match eval_result {
163 AnyEvaluation::Layer(eval) => self.compute_sampling_key_for_layer(
164 &base_eval_res.name,
165 eval.allocated_experiment_name.as_deref().unwrap_or("null"),
166 parameter_name_for_layer.unwrap_or("null"),
167 &base_eval_res.rule_id,
168 user,
169 ),
170 _ => self.compute_sampling_key_for_gate_or_config(
171 &base_eval_res.name,
172 &base_eval_res.rule_id,
173 &eval_result.get_gate_bool_value(),
174 user,
175 ),
176 }
177 }
178
179 fn compute_sampling_key_for_gate_or_config(
181 &self,
182 name: &str,
183 rule_id: &str,
184 value: &bool,
185 user: &StatsigUserInternal,
186 ) -> String {
187 let user_key = self.compute_user_key(user);
188 format!("n:{name};u:{user_key};r:{rule_id};v:{value}")
189 }
190
191 fn compute_sampling_key_for_layer(
193 &self,
194 layer_name: &str,
195 experiment_name: &str,
196 parameter_name: &str,
197 rule_id: &str,
198 user: &StatsigUserInternal,
199 ) -> String {
200 let user_key = self.compute_user_key(user);
201 format!("n:{layer_name};e:{experiment_name};p:{parameter_name};u:{user_key};r:{rule_id}")
202 }
203
204 fn compute_user_key(&self, user: &StatsigUserInternal) -> String {
205 let user_data = &user.user_data;
206
207 let mut user_key = format!(
208 "u:{};",
209 user_data
210 .user_id
211 .as_ref()
212 .and_then(|id| id.string_value.as_deref())
213 .unwrap_or("")
214 );
215
216 if let Some(custom_ids) = user_data.custom_ids.as_ref() {
217 for (key, val) in custom_ids {
218 if let Some(string_value) = &val.string_value {
219 user_key.push_str(&format!("{key}:{string_value};"));
220 }
221 }
222 };
223
224 user_key
225 }
226
227 fn evaluate_exposure_sending(
235 &self,
236 eval_result: &AnyEvaluation,
237 sampling_exposure_key: &String,
238 ) -> (bool, Option<u64>) {
239 let eval_base_res = eval_result.get_base_result();
240 let special_case_sampling_rate = self.get_special_case_sampling_rate();
241
242 if SPECIAL_CASE_RULES.contains(&eval_base_res.rule_id.as_str())
243 && special_case_sampling_rate.is_some()
244 {
245 if let Some(special_rate) = special_case_sampling_rate {
246 let should_send_exposures =
247 self.is_hash_in_sampling_rate(sampling_exposure_key, special_rate);
248 return (should_send_exposures, Some(special_rate));
249 }
250 }
251
252 if let Some(rate) = eval_base_res
253 .sampling_info
254 .as_ref()
255 .and_then(|info| info.sampling_rate)
256 {
257 let should_send_exposures = self.is_hash_in_sampling_rate(sampling_exposure_key, rate);
258
259 return (should_send_exposures, Some(rate));
260 }
261
262 (true, None) }
264
265 fn should_skip_sampling(
266 &self,
267 eval_result: &AnyEvaluation,
268 env: &Option<HashMap<String, DynamicValue>>,
269 ) -> bool {
270 let sampling_mode = self.get_sampling_mode();
271
272 if matches!(sampling_mode, SamplingMode::None) {
273 return true;
274 }
275
276 let sampling_info = eval_result.get_base_result().sampling_info.as_ref();
277
278 if sampling_info
279 .and_then(|info| info.forward_all_exposures)
280 .unwrap_or(false)
281 {
282 return true;
283 }
284
285 if sampling_info
286 .and_then(|info| info.has_seen_analytical_gates)
287 .unwrap_or(false)
288 {
289 return true;
290 }
291
292 if env
294 .as_ref()
295 .and_then(|e| e.get("tier"))
296 .and_then(|tier| tier.string_value.as_deref())
297 != Some("production")
298 {
299 return true;
300 }
301
302 let rule_id = &eval_result.get_base_result().rule_id;
303 if rule_id.ends_with(":override") || rule_id.ends_with(":id_override") {
304 return true;
305 }
306
307 false
308 }
309
310 fn is_hash_in_sampling_rate(&self, key: &String, sampling_rate: u64) -> bool {
311 let hash_value = self.hashing.sha256_to_u64(key);
312 hash_value % sampling_rate == 0
313 }
314
315 fn get_sampling_mode(&self) -> SamplingMode {
316 self.global_configs
317 .get_sdk_config_value("sampling_mode")
318 .and_then(|mode| mode.string_value)
319 .as_deref()
320 .map_or(SamplingMode::None, |mode_str| match mode_str {
321 "on" => SamplingMode::On,
322 "shadow" => SamplingMode::Shadow,
323 _ => SamplingMode::None,
324 })
325 }
326
327 fn get_special_case_sampling_rate(&self) -> Option<u64> {
328 self.global_configs
329 .get_sdk_config_value("special_case_sampling_rate")
330 .and_then(|v| v.float_value)
331 .map(|rate| rate as u64)
332 }
333}
334
// Unit tests for `SamplingProcessor::should_skip_sampling`.
#[cfg(test)]
mod tests {
 use super::*;
 use crate::evaluation::evaluation_types::{BaseEvaluation, GateEvaluation};
 use crate::{SpecStore, SpecsSource, SpecsUpdate, StatsigUser};
 use serde_json::Value;
 use std::fs;
 use std::sync::LazyLock;

 // Shared gate evaluation fixture: a `false` gate with default sampling info
 // and a rule ID that is neither special-cased nor an override.
 static GATE: LazyLock<GateEvaluation> = LazyLock::new(|| GateEvaluation {
 base: BaseEvaluation {
 name: "publish_to_all".to_string(),
 rule_id: "rule_id".to_string(),
 secondary_exposures: vec![],
 sampling_info: Default::default(),
 },
 id_type: String::new(),
 value: false,
 });

 // Builds a user with two custom IDs whose environment tier is "development".
 fn create_mock_user() -> StatsigUserInternal {
 let mut custom_ids = HashMap::new();
 custom_ids.insert("k1".to_string(), "v1".to_string());
 custom_ids.insert("k2".to_string(), "v2".to_string());

 StatsigUserInternal {
 user_data: StatsigUser::with_custom_ids(custom_ids),
 statsig_environment: Some(HashMap::from([(
 "tier".to_string(),
 DynamicValue {
 string_value: Some("development".to_string()),
 ..Default::default()
 },
 )])),
 }
 }

 // Wraps the shared `GATE` fixture as a feature-gate evaluation.
 fn create_mock_evaluation_result() -> AnyEvaluation<'static> {
 AnyEvaluation::FeatureGate(&GATE)
 }

 // Sampling must be skipped for a non-production tier and applied for "production".
 #[test]
 fn test_should_skip_sampling() {
 // Load the specs fixture before the processor is used.
 // NOTE(review): presumably this fixture sets a non-"none" `sampling_mode` for the
 // SDK key "" used below; the second assertion depends on it -- confirm.
 let file_path = "tests/data/dcs_with_sdk_configs.json";
 let file_content = fs::read_to_string(file_path).expect("Unable to read file");
 let json_data: Value = serde_json::from_str(&file_content).expect("Unable to parse JSON");

 let specs_update = SpecsUpdate {
 data: json_data.to_string(),
 source: SpecsSource::Network,
 received_at: 2000,
 };

 let spec_store = SpecStore::default();
 spec_store
 .set_values(specs_update)
 .expect("Set Specstore failed");

 let runtime = StatsigRuntime::get_runtime();
 let hashing = Arc::new(HashUtil::new());
 let processor = SamplingProcessor::new(&runtime, hashing, "");

 let mut test_user = create_mock_user();
 let mock_evaluation_res = create_mock_evaluation_result();

 // The mock user starts in the "development" tier, so sampling is skipped.
 let should_skip_sample =
 processor.should_skip_sampling(&mock_evaluation_res, &test_user.statsig_environment);
 assert!(should_skip_sample);

 // Switching the tier to "production" removes that reason for skipping.
 test_user.statsig_environment = Some(HashMap::from([(
 "tier".to_string(),
 DynamicValue {
 string_value: Some("production".to_string()),
 ..Default::default()
 },
 )]));
 let should_skip_sample_2 =
 processor.should_skip_sampling(&mock_evaluation_res, &test_user.statsig_environment);
 assert!(!should_skip_sample_2);
 }
}