1use crate::evaluation::evaluation_types::AnyEvaluation;
2use crate::global_configs::GlobalConfigs;
3use crate::hashing::HashUtil;
4use crate::hashset_with_ttl::HashSetWithTTL;
5use crate::{DynamicValue, StatsigErr, StatsigRuntime};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::time::Duration;
9
/// Rule IDs for which the SDK-config `special_case_sampling_rate`, when
/// present, takes precedence over the evaluation's own sampling rate.
const SPECIAL_CASE_RULES: [&str; 3] = ["disabled", "default", ""];

/// TTL, in seconds, handed to the `HashSetWithTTL` that remembers which
/// `{name}_{rule_id}` keys have already been seen by the sampling processor.
const TTL_IN_SECONDS: u64 = 60;
12
13#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum SamplingStatus {
16 Logged,
17 Dropped,
18 #[default]
19 None,
20}
21
22#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum SamplingMode {
25 On,
26 Shadow,
27 #[default]
28 None,
29}
30
31#[derive(Default)]
32pub struct SamplingDecision {
33 pub should_send_exposure: bool,
34 pub sampling_rate: Option<u64>,
35 pub sampling_status: SamplingStatus,
36 pub sampling_mode: SamplingMode,
37}
38
39impl SamplingDecision {
40 pub fn new(
41 should_send_exposure: bool,
42 sampling_rate: Option<u64>,
43 sampling_status: SamplingStatus,
44 sampling_mode: SamplingMode,
45 ) -> Self {
46 Self {
47 should_send_exposure,
48 sampling_rate,
49 sampling_status,
50 sampling_mode,
51 }
52 }
53
54 pub fn force_logged() -> Self {
55 Self {
56 should_send_exposure: true,
57 sampling_rate: None,
58 sampling_status: SamplingStatus::None,
59 sampling_mode: SamplingMode::None,
60 }
61 }
62}
63
64pub struct SamplingProcessor {
65 sampling_key_set: HashSetWithTTL,
66 hashing: Arc<HashUtil>,
67 global_configs: Arc<GlobalConfigs>,
68}
69
70impl SamplingProcessor {
71 pub fn new(
72 statsig_runtime: &Arc<StatsigRuntime>,
73 hashing: Arc<HashUtil>,
74 sdk_key: &str,
75 ) -> Self {
76 let sampling_key_set =
77 HashSetWithTTL::new(statsig_runtime, Duration::from_secs(TTL_IN_SECONDS));
78
79 Self {
80 sampling_key_set,
81 hashing,
82 global_configs: GlobalConfigs::get_instance(sdk_key),
83 }
84 }
85
86 pub async fn shutdown(&self, _timeout: Duration) -> Result<(), StatsigErr> {
87 self.sampling_key_set.shutdown().await;
88 Ok(())
89 }
90
91 pub fn get_sampling_decision_and_details(
92 &self,
93 user_sampling_key: &str,
94 eval_result: Option<&AnyEvaluation>,
95 parameter_name_for_layer: Option<&str>,
96 ) -> SamplingDecision {
97 let eval_result = match eval_result {
98 Some(result) => result,
99 None => return SamplingDecision::force_logged(),
100 };
101
102 if self.should_skip_sampling(eval_result) {
103 return SamplingDecision::force_logged();
104 }
105
106 let base_eval_res = eval_result.get_base_result();
107 let sampling_ttl_set_key = format!("{}_{}", base_eval_res.name, base_eval_res.rule_id);
108
109 if !self
110 .sampling_key_set
111 .contains(&sampling_ttl_set_key)
112 .unwrap_or(false)
113 {
114 let _ = self.sampling_key_set.add(sampling_ttl_set_key);
115 return SamplingDecision::force_logged();
116 }
117
118 let sampling_mode = self.get_sampling_mode();
119 let sampling_exposure_key = self.compute_sampling_exposure_key(
120 eval_result,
121 user_sampling_key,
122 parameter_name_for_layer,
123 );
124
125 let (should_send_exposures, sampling_rate) =
126 self.evaluate_exposure_sending(eval_result, &sampling_exposure_key);
127
128 let sampling_log_status = match sampling_rate {
129 None => SamplingStatus::None, Some(_) if should_send_exposures => SamplingStatus::Logged,
131 Some(_) => SamplingStatus::Dropped,
132 };
133
134 match sampling_mode {
135 SamplingMode::On => SamplingDecision::new(
136 should_send_exposures,
137 sampling_rate,
138 sampling_log_status,
139 SamplingMode::On,
140 ),
141 SamplingMode::Shadow => SamplingDecision::new(
142 true,
143 sampling_rate,
144 sampling_log_status,
145 SamplingMode::Shadow,
146 ),
147 _ => SamplingDecision::force_logged(),
148 }
149 }
150
151 fn compute_sampling_exposure_key(
156 &self,
157 eval_result: &AnyEvaluation,
158 user_sampling_key: &str,
159 parameter_name_for_layer: Option<&str>,
160 ) -> String {
161 let base_eval_res = eval_result.get_base_result();
162
163 match eval_result {
164 AnyEvaluation::Layer(eval) => self.compute_sampling_key_for_layer(
165 &base_eval_res.name,
166 eval.allocated_experiment_name.as_deref().unwrap_or("null"),
167 parameter_name_for_layer.unwrap_or("null"),
168 &base_eval_res.rule_id,
169 user_sampling_key,
170 ),
171 _ => self.compute_sampling_key_for_gate_or_config(
172 &base_eval_res.name,
173 &base_eval_res.rule_id,
174 &eval_result.get_gate_bool_value(),
175 user_sampling_key,
176 ),
177 }
178 }
179
180 fn compute_sampling_key_for_gate_or_config(
182 &self,
183 name: &str,
184 rule_id: &str,
185 value: &bool,
186 user_sampling_key: &str,
187 ) -> String {
188 format!("n:{name};u:{user_sampling_key};r:{rule_id};v:{value}")
189 }
190
191 fn compute_sampling_key_for_layer(
193 &self,
194 layer_name: &str,
195 experiment_name: &str,
196 parameter_name: &str,
197 rule_id: &str,
198 user_sampling_key: &str,
199 ) -> String {
200 format!("n:{layer_name};e:{experiment_name};p:{parameter_name};u:{user_sampling_key};r:{rule_id}")
201 }
202
203 fn evaluate_exposure_sending(
211 &self,
212 eval_result: &AnyEvaluation,
213 sampling_exposure_key: &str,
214 ) -> (bool, Option<u64>) {
215 let eval_base_res = eval_result.get_base_result();
216 let special_case_sampling_rate = self.get_special_case_sampling_rate();
217
218 if SPECIAL_CASE_RULES.contains(&eval_base_res.rule_id.as_str())
219 && special_case_sampling_rate.is_some()
220 {
221 if let Some(special_rate) = special_case_sampling_rate {
222 let should_send_exposures =
223 self.is_hash_in_sampling_rate(sampling_exposure_key, special_rate);
224 return (should_send_exposures, Some(special_rate));
225 }
226 }
227
228 if let Some(rate) = eval_base_res
229 .sampling_info
230 .as_ref()
231 .and_then(|info| info.sampling_rate)
232 {
233 let should_send_exposures = self.is_hash_in_sampling_rate(sampling_exposure_key, rate);
234
235 return (should_send_exposures, Some(rate));
236 }
237
238 (true, None) }
240
241 fn should_skip_sampling(&self, eval_result: &AnyEvaluation) -> bool {
242 let sampling_mode = self.get_sampling_mode();
243
244 if matches!(sampling_mode, SamplingMode::None) {
245 return true;
246 }
247
248 let sampling_info = eval_result.get_base_result().sampling_info.as_ref();
249
250 if sampling_info
251 .and_then(|info| info.forward_all_exposures)
252 .unwrap_or(false)
253 {
254 return true;
255 }
256
257 if sampling_info
258 .and_then(|info| info.has_seen_analytical_gates)
259 .unwrap_or(false)
260 {
261 return true;
262 }
263
264 false
265 }
266
267 fn is_hash_in_sampling_rate(&self, key: &str, sampling_rate: u64) -> bool {
268 let hash_value = self.hashing.sha256_to_u64(key);
269 hash_value % sampling_rate == 0
270 }
271
272 fn get_sampling_mode(&self) -> SamplingMode {
273 fn parse_sampling_mode(value: Option<&DynamicValue>) -> SamplingMode {
274 match value {
275 Some(value) => match value.string_value.as_ref().map(|s| s.value.as_str()) {
276 Some("on") => SamplingMode::On,
277 Some("shadow") => SamplingMode::Shadow,
278 _ => SamplingMode::None,
279 },
280 None => SamplingMode::None,
281 }
282 }
283
284 self.global_configs
285 .use_sdk_config_value("sampling_mode", parse_sampling_mode)
286 }
287
288 fn get_special_case_sampling_rate(&self) -> Option<u64> {
289 fn parse_special_case_sampling_rate(value: Option<&DynamicValue>) -> Option<u64> {
290 match value {
291 Some(value) => value.float_value.map(|rate| rate as u64),
292 None => None,
293 }
294 }
295
296 self.global_configs.use_sdk_config_value(
297 "special_case_sampling_rate",
298 parse_special_case_sampling_rate,
299 )
300 }
301}