trustformers_core/ab_testing/
routing.rs1use super::experiment::ExperimentStatus;
4use super::{Experiment, Variant};
5use anyhow::Result;
6use scirs2_core::random::*;
7use std::collections::hash_map::DefaultHasher;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11#[derive(Debug, Clone)]
13pub enum RoutingStrategy {
14 HashBased,
16 RoundRobin,
18 SegmentBased(Vec<UserSegment>),
20 WeightedRandom(HashMap<String, f64>),
22 Sticky,
24}
25
26#[derive(Debug, Clone)]
28pub struct UserSegment {
29 pub name: String,
31 pub condition: SegmentCondition,
33 pub variant_name: String,
35}
36
37#[derive(Debug, Clone)]
39pub enum SegmentCondition {
40 UserIdPattern(String),
42 HasAttribute(String, String),
44 GeoRegion(String),
46 Platform(Platform),
48 Custom(String), }
51
/// Client platform a user is accessing the product from.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so platforms can be used as
/// keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Platform {
    /// Apple iOS.
    Ios,
    /// Google Android.
    Android,
    /// Web browser.
    Web,
    /// Desktop application.
    Desktop,
}
60
61pub struct TrafficSplitter {
63 default_strategy: RoutingStrategy,
65 sticky_cache: parking_lot::RwLock<HashMap<String, String>>,
67 round_robin_counters: parking_lot::RwLock<HashMap<String, usize>>,
69}
70
71impl Default for TrafficSplitter {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl TrafficSplitter {
78 pub fn new() -> Self {
80 Self {
81 default_strategy: RoutingStrategy::HashBased,
82 sticky_cache: parking_lot::RwLock::new(HashMap::new()),
83 round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
84 }
85 }
86
87 pub fn with_strategy(strategy: RoutingStrategy) -> Self {
89 Self {
90 default_strategy: strategy,
91 sticky_cache: parking_lot::RwLock::new(HashMap::new()),
92 round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
93 }
94 }
95
96 pub fn route(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
98 if experiment.status() != ExperimentStatus::Running {
100 return Ok(experiment.config().control_variant.clone());
101 }
102
103 match &self.default_strategy {
105 RoutingStrategy::HashBased => self.route_hash_based(experiment, user_id),
106 RoutingStrategy::RoundRobin => self.route_round_robin(experiment),
107 RoutingStrategy::SegmentBased(segments) => {
108 self.route_segment_based(experiment, user_id, segments)
109 },
110 RoutingStrategy::WeightedRandom(weights) => {
111 self.route_weighted_random(experiment, weights)
112 },
113 RoutingStrategy::Sticky => self.route_sticky(experiment, user_id),
114 }
115 }
116
117 #[allow(dead_code)]
119 fn should_include_in_experiment(&self, experiment: &Experiment, user_id: &str) -> bool {
120 let hash = self.hash_user_id(user_id, &experiment.id().to_string());
121 let threshold = (experiment.config().traffic_percentage / 100.0 * u64::MAX as f64) as u64;
122 hash < threshold
123 }
124
125 fn route_hash_based(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
127 let hash = self.hash_user_id(user_id, &experiment.id().to_string());
128 let control_variant = &experiment.config().control_variant;
129 let treatment_variants = &experiment.config().treatment_variants;
130
131 if treatment_variants.is_empty() {
132 return Ok(control_variant.clone());
133 }
134
135 let traffic_percentage = experiment.config().traffic_percentage;
137 let threshold = (traffic_percentage / 100.0 * u64::MAX as f64) as u64;
138
139 if hash < threshold {
140 let treatment_index = (hash as usize) % treatment_variants.len();
142 Ok(treatment_variants[treatment_index].clone())
143 } else {
144 Ok(control_variant.clone())
146 }
147 }
148
149 fn route_round_robin(&self, experiment: &Experiment) -> Result<Variant> {
151 let variants = experiment.all_variants();
152 let mut counters = self.round_robin_counters.write();
153 let counter = counters.entry(experiment.id().to_string()).or_insert(0);
154 let index = *counter % variants.len();
155 *counter += 1;
156 Ok(variants[index].clone())
157 }
158
159 fn route_segment_based(
161 &self,
162 experiment: &Experiment,
163 user_id: &str,
164 segments: &[UserSegment],
165 ) -> Result<Variant> {
166 for segment in segments {
168 if !self.matches_segment(user_id, &segment.condition) {
169 continue;
170 }
171
172 for variant in experiment.all_variants() {
174 if variant.name() == segment.variant_name {
175 return Ok(variant.clone());
176 }
177 }
178 }
179
180 self.route_hash_based(experiment, user_id)
182 }
183
184 fn route_weighted_random(
186 &self,
187 experiment: &Experiment,
188 weights: &HashMap<String, f64>,
189 ) -> Result<Variant> {
190 let variants = experiment.all_variants();
191 let total_weight: f64 = weights.values().sum();
192
193 if total_weight == 0.0 {
194 return self.route_hash_based(experiment, &uuid::Uuid::new_v4().to_string());
196 }
197
198 let mut rng = thread_rng();
199 let random_value: f64 = rng.random_range(0.0..total_weight);
200 let mut cumulative_weight = 0.0;
201
202 for variant in variants {
203 let weight = weights.get(variant.name()).unwrap_or(&1.0);
204 cumulative_weight += weight;
205 if random_value < cumulative_weight {
206 return Ok(variant.clone());
207 }
208 }
209
210 Ok(experiment.config().control_variant.clone())
212 }
213
214 fn route_sticky(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
216 let cache_key = format!("{}:{}", experiment.id(), user_id);
217
218 if let Some(cached_variant) = self.get_cached_variant(experiment, &cache_key) {
220 return Ok(cached_variant);
221 }
222
223 let variant = self.route_hash_based(experiment, user_id)?;
225 {
226 let mut cache = self.sticky_cache.write();
227 cache.insert(cache_key, variant.name().to_string());
228 }
229
230 Ok(variant)
231 }
232
233 fn get_cached_variant(&self, experiment: &Experiment, cache_key: &str) -> Option<Variant> {
235 let cache = self.sticky_cache.read();
236 let variant_name = cache.get(cache_key)?;
237
238 for variant in experiment.all_variants() {
239 if variant.name() == variant_name {
240 return Some(variant.clone());
241 }
242 }
243 None
244 }
245
246 fn matches_segment(&self, user_id: &str, condition: &SegmentCondition) -> bool {
248 match condition {
249 SegmentCondition::UserIdPattern(pattern) => {
250 user_id.contains(pattern)
252 },
253 SegmentCondition::HasAttribute(_, _) => {
254 false
256 },
257 SegmentCondition::GeoRegion(_) => {
258 false
260 },
261 SegmentCondition::Platform(_) => {
262 false
264 },
265 SegmentCondition::Custom(_) => {
266 false
268 },
269 }
270 }
271
272 fn hash_user_id(&self, user_id: &str, experiment_id: &str) -> u64 {
274 let mut hasher = DefaultHasher::new();
275 user_id.hash(&mut hasher);
276 experiment_id.hash(&mut hasher);
277 hasher.finish()
278 }
279
280 pub fn clear_sticky_cache(&self, experiment_id: &str) {
282 let mut cache = self.sticky_cache.write();
283 cache.retain(|k, _| !k.starts_with(&format!("{}:", experiment_id)));
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ab_testing::ExperimentConfig;

    /// Builds and starts a two-variant experiment ("control" / "treatment") whose
    /// treatment bucket covers `traffic_percentage` percent of users.
    fn build_experiment(traffic_percentage: f64) -> Experiment {
        let config = ExperimentConfig {
            name: "Test".to_string(),
            description: "Test".to_string(),
            control_variant: Variant::new("control", "v1"),
            treatment_variants: vec![Variant::new("treatment", "v2")],
            traffic_percentage,
            min_sample_size: 100,
            max_duration_hours: 24,
        };
        let mut exp = Experiment::new(config).expect("operation failed in test");
        exp.start().expect("operation failed in test");
        exp
    }

    /// Running experiment whose treatment bucket covers all users.
    fn create_test_experiment() -> Experiment {
        build_experiment(100.0)
    }

    #[test]
    fn test_hash_based_routing() {
        let splitter = TrafficSplitter::new();
        let experiment = create_test_experiment();

        let user_id = "test-user-123";
        let variant1 = splitter.route(&experiment, user_id).expect("operation failed in test");
        let variant2 = splitter.route(&experiment, user_id).expect("operation failed in test");
        assert_eq!(variant1, variant2);
    }

    #[test]
    fn test_round_robin_routing() {
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::RoundRobin);
        let experiment = create_test_experiment();

        let mut control_count = 0;
        let mut treatment_count = 0;

        for _ in 0..10 {
            let variant =
                splitter.route(&experiment, "any-user").expect("operation failed in test");
            match variant.name() {
                "control" => control_count += 1,
                "treatment" => treatment_count += 1,
                name => panic!("Unexpected variant name: {}", name),
            }
        }

        assert_eq!(control_count, 5);
        assert_eq!(treatment_count, 5);
    }

    #[test]
    fn test_sticky_routing() {
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::Sticky);
        let experiment = create_test_experiment();

        let user_id = "sticky-user";
        let first_variant = splitter.route(&experiment, user_id).expect("operation failed in test");

        for _ in 0..10 {
            let variant = splitter.route(&experiment, user_id).expect("operation failed in test");
            assert_eq!(variant, first_variant);
        }
    }

    #[test]
    fn test_clear_sticky_cache() {
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::Sticky);
        let experiment = create_test_experiment();

        splitter.route(&experiment, "sticky-user").expect("operation failed in test");
        assert_eq!(splitter.sticky_cache.read().len(), 1);

        splitter.clear_sticky_cache(&experiment.id().to_string());
        assert!(splitter.sticky_cache.read().is_empty());
    }

    #[test]
    fn test_weighted_random_respects_zero_weight() {
        let mut weights = std::collections::HashMap::new();
        weights.insert("control".to_string(), 0.0);
        weights.insert("treatment".to_string(), 1.0);
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::WeightedRandom(weights));
        let experiment = create_test_experiment();

        for i in 0..20 {
            let user_id = format!("user-{}", i);
            let variant = splitter.route(&experiment, &user_id).expect("operation failed in test");
            assert_eq!(variant.name(), "treatment");
        }
    }

    #[test]
    fn test_segment_routing_prefers_matching_segment() {
        let segments = vec![UserSegment {
            name: "beta".to_string(),
            condition: SegmentCondition::UserIdPattern("beta".to_string()),
            variant_name: "treatment".to_string(),
        }];
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::SegmentBased(segments));
        let experiment = build_experiment(10.0);

        let matched = splitter.route(&experiment, "beta-tester").expect("operation failed in test");
        assert_eq!(matched.name(), "treatment");

        // Users outside every segment fall back to plain hash-based routing.
        let hashed = TrafficSplitter::new()
            .route(&experiment, "regular-user")
            .expect("operation failed in test");
        let routed = splitter.route(&experiment, "regular-user").expect("operation failed in test");
        assert_eq!(routed, hashed);
    }

    #[test]
    fn test_traffic_percentage() {
        let exp = build_experiment(10.0);
        let splitter = TrafficSplitter::new();
        let mut included_count = 0;

        for i in 0..1000 {
            let user_id = format!("user-{}", i);
            let variant = splitter.route(&exp, &user_id).expect("operation failed in test");
            let included = splitter.should_include_in_experiment(&exp, &user_id);

            // Hash routing and the inclusion check use the same bucketing, so they
            // must agree for every user (an `||` of the two would hide mismatches).
            assert_eq!(variant.name() != "control", included, "mismatch for {}", user_id);
            if included {
                included_count += 1;
            }
        }

        let inclusion_rate = included_count as f64 / 1000.0;
        assert!(
            (inclusion_rate - 0.1).abs() < 0.05,
            "inclusion rate was {}",
            inclusion_rate
        );
    }
}