Skip to main content

trustformers_core/ab_testing/
routing.rs

//! Traffic routing and splitting strategies.
2
3use super::experiment::ExperimentStatus;
4use super::{Experiment, Variant};
5use anyhow::Result;
6use scirs2_core::random::*;
7use std::collections::hash_map::DefaultHasher;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11/// Traffic routing strategy
12#[derive(Debug, Clone)]
13pub enum RoutingStrategy {
14    /// Random assignment based on hash
15    HashBased,
16    /// Round-robin assignment
17    RoundRobin,
18    /// User segment based routing
19    SegmentBased(Vec<UserSegment>),
20    /// Weighted random assignment
21    WeightedRandom(HashMap<String, f64>),
22    /// Sticky sessions (user always gets same variant)
23    Sticky,
24}
25
26/// User segment for targeted routing
27#[derive(Debug, Clone)]
28pub struct UserSegment {
29    /// Segment name
30    pub name: String,
31    /// Condition to match users
32    pub condition: SegmentCondition,
33    /// Variant to assign to this segment
34    pub variant_name: String,
35}
36
37/// Conditions for segment matching
38#[derive(Debug, Clone)]
39pub enum SegmentCondition {
40    /// User ID matches pattern
41    UserIdPattern(String),
42    /// User has specific attribute
43    HasAttribute(String, String),
44    /// User is in specific geographic region
45    GeoRegion(String),
46    /// User is using specific platform
47    Platform(Platform),
48    /// Custom condition function
49    Custom(String), // In practice, this would be a function pointer
50}
51
/// Client platform a user can be on, used by [`SegmentCondition::Platform`].
///
/// The enum carries no data, so it is `Copy` and can be used as a hash-map or
/// hash-set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Platform {
    /// iOS client.
    Ios,
    /// Android client.
    Android,
    /// Web client.
    Web,
    /// Desktop client.
    Desktop,
}
60
61/// Traffic splitter implementation
62pub struct TrafficSplitter {
63    /// Default routing strategy
64    default_strategy: RoutingStrategy,
65    /// Cache for sticky sessions
66    sticky_cache: parking_lot::RwLock<HashMap<String, String>>,
67    /// Round-robin counters
68    round_robin_counters: parking_lot::RwLock<HashMap<String, usize>>,
69}
70
71impl Default for TrafficSplitter {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl TrafficSplitter {
78    /// Create a new traffic splitter
79    pub fn new() -> Self {
80        Self {
81            default_strategy: RoutingStrategy::HashBased,
82            sticky_cache: parking_lot::RwLock::new(HashMap::new()),
83            round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
84        }
85    }
86
87    /// Create with specific strategy
88    pub fn with_strategy(strategy: RoutingStrategy) -> Self {
89        Self {
90            default_strategy: strategy,
91            sticky_cache: parking_lot::RwLock::new(HashMap::new()),
92            round_robin_counters: parking_lot::RwLock::new(HashMap::new()),
93        }
94    }
95
96    /// Route a user to a variant
97    pub fn route(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
98        // Check if experiment is running
99        if experiment.status() != ExperimentStatus::Running {
100            return Ok(experiment.config().control_variant.clone());
101        }
102
103        // Route based on strategy (traffic percentage is handled within the strategy)
104        match &self.default_strategy {
105            RoutingStrategy::HashBased => self.route_hash_based(experiment, user_id),
106            RoutingStrategy::RoundRobin => self.route_round_robin(experiment),
107            RoutingStrategy::SegmentBased(segments) => {
108                self.route_segment_based(experiment, user_id, segments)
109            },
110            RoutingStrategy::WeightedRandom(weights) => {
111                self.route_weighted_random(experiment, weights)
112            },
113            RoutingStrategy::Sticky => self.route_sticky(experiment, user_id),
114        }
115    }
116
117    /// Check if user should be included in experiment
118    #[allow(dead_code)]
119    fn should_include_in_experiment(&self, experiment: &Experiment, user_id: &str) -> bool {
120        let hash = self.hash_user_id(user_id, &experiment.id().to_string());
121        let threshold = (experiment.config().traffic_percentage / 100.0 * u64::MAX as f64) as u64;
122        hash < threshold
123    }
124
125    /// Hash-based routing
126    fn route_hash_based(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
127        let hash = self.hash_user_id(user_id, &experiment.id().to_string());
128        let control_variant = &experiment.config().control_variant;
129        let treatment_variants = &experiment.config().treatment_variants;
130
131        if treatment_variants.is_empty() {
132            return Ok(control_variant.clone());
133        }
134
135        // Use traffic percentage to decide between control and treatment
136        let traffic_percentage = experiment.config().traffic_percentage;
137        let threshold = (traffic_percentage / 100.0 * u64::MAX as f64) as u64;
138
139        if hash < threshold {
140            // User goes to treatment - pick randomly among treatment variants
141            let treatment_index = (hash as usize) % treatment_variants.len();
142            Ok(treatment_variants[treatment_index].clone())
143        } else {
144            // User goes to control
145            Ok(control_variant.clone())
146        }
147    }
148
149    /// Round-robin routing
150    fn route_round_robin(&self, experiment: &Experiment) -> Result<Variant> {
151        let variants = experiment.all_variants();
152        let mut counters = self.round_robin_counters.write();
153        let counter = counters.entry(experiment.id().to_string()).or_insert(0);
154        let index = *counter % variants.len();
155        *counter += 1;
156        Ok(variants[index].clone())
157    }
158
159    /// Segment-based routing
160    fn route_segment_based(
161        &self,
162        experiment: &Experiment,
163        user_id: &str,
164        segments: &[UserSegment],
165    ) -> Result<Variant> {
166        // Check if user matches any segment
167        for segment in segments {
168            if !self.matches_segment(user_id, &segment.condition) {
169                continue;
170            }
171
172            // Find the variant by name
173            for variant in experiment.all_variants() {
174                if variant.name() == segment.variant_name {
175                    return Ok(variant.clone());
176                }
177            }
178        }
179
180        // Fall back to hash-based routing
181        self.route_hash_based(experiment, user_id)
182    }
183
184    /// Weighted random routing
185    fn route_weighted_random(
186        &self,
187        experiment: &Experiment,
188        weights: &HashMap<String, f64>,
189    ) -> Result<Variant> {
190        let variants = experiment.all_variants();
191        let total_weight: f64 = weights.values().sum();
192
193        if total_weight == 0.0 {
194            // Fall back to equal weights
195            return self.route_hash_based(experiment, &uuid::Uuid::new_v4().to_string());
196        }
197
198        let mut rng = thread_rng();
199        let random_value: f64 = rng.random_range(0.0..total_weight);
200        let mut cumulative_weight = 0.0;
201
202        for variant in variants {
203            let weight = weights.get(variant.name()).unwrap_or(&1.0);
204            cumulative_weight += weight;
205            if random_value < cumulative_weight {
206                return Ok(variant.clone());
207            }
208        }
209
210        // Fallback to control
211        Ok(experiment.config().control_variant.clone())
212    }
213
214    /// Sticky session routing
215    fn route_sticky(&self, experiment: &Experiment, user_id: &str) -> Result<Variant> {
216        let cache_key = format!("{}:{}", experiment.id(), user_id);
217
218        // Check cache
219        if let Some(cached_variant) = self.get_cached_variant(experiment, &cache_key) {
220            return Ok(cached_variant);
221        }
222
223        // Not in cache, route and store
224        let variant = self.route_hash_based(experiment, user_id)?;
225        {
226            let mut cache = self.sticky_cache.write();
227            cache.insert(cache_key, variant.name().to_string());
228        }
229
230        Ok(variant)
231    }
232
233    /// Get cached variant if it exists
234    fn get_cached_variant(&self, experiment: &Experiment, cache_key: &str) -> Option<Variant> {
235        let cache = self.sticky_cache.read();
236        let variant_name = cache.get(cache_key)?;
237
238        for variant in experiment.all_variants() {
239            if variant.name() == variant_name {
240                return Some(variant.clone());
241            }
242        }
243        None
244    }
245
246    /// Check if user matches segment condition
247    fn matches_segment(&self, user_id: &str, condition: &SegmentCondition) -> bool {
248        match condition {
249            SegmentCondition::UserIdPattern(pattern) => {
250                // Simple pattern matching (in practice, use regex)
251                user_id.contains(pattern)
252            },
253            SegmentCondition::HasAttribute(_, _) => {
254                // Would require user attribute lookup
255                false
256            },
257            SegmentCondition::GeoRegion(_) => {
258                // Would require geo lookup
259                false
260            },
261            SegmentCondition::Platform(_) => {
262                // Would require platform detection
263                false
264            },
265            SegmentCondition::Custom(_) => {
266                // Would execute custom function
267                false
268            },
269        }
270    }
271
272    /// Hash user ID with experiment ID for consistent assignment
273    fn hash_user_id(&self, user_id: &str, experiment_id: &str) -> u64 {
274        let mut hasher = DefaultHasher::new();
275        user_id.hash(&mut hasher);
276        experiment_id.hash(&mut hasher);
277        hasher.finish()
278    }
279
280    /// Clear sticky cache for an experiment
281    pub fn clear_sticky_cache(&self, experiment_id: &str) {
282        let mut cache = self.sticky_cache.write();
283        cache.retain(|k, _| !k.starts_with(&format!("{}:", experiment_id)));
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ab_testing::ExperimentConfig;

    /// Builds and starts an experiment with one control and one treatment variant,
    /// exposing `traffic_percentage` percent of users to the treatment.
    fn started_experiment(traffic_percentage: f64) -> Experiment {
        let config = ExperimentConfig {
            name: "Test".to_string(),
            description: "Test".to_string(),
            control_variant: Variant::new("control", "v1"),
            treatment_variants: vec![Variant::new("treatment", "v2")],
            traffic_percentage,
            min_sample_size: 100,
            max_duration_hours: 24,
        };
        let mut experiment = Experiment::new(config).expect("operation failed in test");
        experiment.start().expect("operation failed in test");
        experiment
    }

    /// Experiment used by most tests: all traffic is eligible for treatment.
    fn create_test_experiment() -> Experiment {
        started_experiment(100.0)
    }

    #[test]
    fn test_hash_based_routing() {
        let splitter = TrafficSplitter::new();
        let experiment = create_test_experiment();

        // Routing the same user twice must be deterministic.
        let user_id = "test-user-123";
        let first = splitter.route(&experiment, user_id).expect("operation failed in test");
        let second = splitter.route(&experiment, user_id).expect("operation failed in test");
        assert_eq!(first, second);
    }

    #[test]
    fn test_round_robin_routing() {
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::RoundRobin);
        let experiment = create_test_experiment();

        let mut control_count = 0;
        let mut treatment_count = 0;

        // Ten requests over two variants must split evenly.
        for _ in 0..10 {
            let variant =
                splitter.route(&experiment, "any-user").expect("operation failed in test");
            match variant.name() {
                "control" => control_count += 1,
                "treatment" => treatment_count += 1,
                name => panic!("Unexpected variant name: {}", name),
            }
        }

        assert_eq!(control_count, 5);
        assert_eq!(treatment_count, 5);
    }

    #[test]
    fn test_sticky_routing() {
        let splitter = TrafficSplitter::with_strategy(RoutingStrategy::Sticky);
        let experiment = create_test_experiment();

        let user_id = "sticky-user";
        let first_variant = splitter.route(&experiment, user_id).expect("operation failed in test");

        // Every later call must return the assignment made by the first one.
        for _ in 0..10 {
            let variant = splitter.route(&experiment, user_id).expect("operation failed in test");
            assert_eq!(variant, first_variant);
        }
    }

    #[test]
    fn test_traffic_percentage() {
        // Only 10% of users are eligible for the treatment.
        let exp = started_experiment(10.0);
        let splitter = TrafficSplitter::new();

        // Count the users that were either served a non-control variant or fall
        // inside the experiment's traffic percentage.
        let included_count = (0..1000)
            .map(|i| format!("user-{}", i))
            .filter(|user_id| {
                let variant = splitter.route(&exp, user_id).expect("operation failed in test");
                variant.name() != "control"
                    || splitter.should_include_in_experiment(&exp, user_id)
            })
            .count();

        // Should be roughly 10% (allow some variance).
        let inclusion_rate = included_count as f64 / 1000.0;
        assert!((inclusion_rate - 0.1).abs() < 0.05);
    }
}