Skip to main content

trustformers_core/ab_testing/
experiment.rs

1//! Experiment definitions and management
2
3use anyhow::Result;
4use chrono::{DateTime, Duration, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8/// Experiment configuration
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ExperimentConfig {
11    /// Name of the experiment
12    pub name: String,
13    /// Description of what is being tested
14    pub description: String,
15    /// Control variant (baseline)
16    pub control_variant: Variant,
17    /// Treatment variants to test
18    pub treatment_variants: Vec<Variant>,
19    /// Percentage of traffic to include in test
20    pub traffic_percentage: f64,
21    /// Minimum sample size per variant
22    pub min_sample_size: usize,
23    /// Maximum duration in hours
24    pub max_duration_hours: u64,
25}
26
27/// A variant in an A/B test
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
29pub struct Variant {
30    /// Unique identifier for the variant
31    name: String,
32    /// Model or configuration identifier
33    model_id: String,
34    /// Optional configuration overrides
35    config_overrides: Option<serde_json::Value>,
36}
37
38impl Variant {
39    /// Create a new variant
40    pub fn new(name: &str, model_id: &str) -> Self {
41        Self {
42            name: name.to_string(),
43            model_id: model_id.to_string(),
44            config_overrides: None,
45        }
46    }
47
48    /// Create a variant with configuration overrides
49    pub fn with_config(name: &str, model_id: &str, config: serde_json::Value) -> Self {
50        Self {
51            name: name.to_string(),
52            model_id: model_id.to_string(),
53            config_overrides: Some(config),
54        }
55    }
56
57    /// Get variant name
58    pub fn name(&self) -> &str {
59        &self.name
60    }
61
62    /// Get model ID
63    pub fn model_id(&self) -> &str {
64        &self.model_id
65    }
66
67    /// Get configuration overrides
68    pub fn config_overrides(&self) -> Option<&serde_json::Value> {
69        self.config_overrides.as_ref()
70    }
71}
72
73/// An A/B test experiment
74#[derive(Debug, Clone)]
75pub struct Experiment {
76    /// Unique experiment ID
77    id: Uuid,
78    /// Configuration
79    config: ExperimentConfig,
80    /// Current status
81    status: ExperimentStatus,
82    /// Start time
83    start_time: Option<DateTime<Utc>>,
84    /// End time
85    end_time: Option<DateTime<Utc>>,
86    /// Metadata
87    metadata: ExperimentMetadata,
88}
89
90/// Experiment metadata
91#[derive(Debug, Clone, Default)]
92pub struct ExperimentMetadata {
93    /// Number of requests per variant
94    pub request_counts: std::collections::HashMap<String, usize>,
95    /// Last update time
96    pub last_updated: Option<DateTime<Utc>>,
97    /// Tags for categorization
98    #[allow(dead_code)]
99    pub tags: Vec<String>,
100    /// Owner/creator
101    #[allow(dead_code)]
102    pub owner: Option<String>,
103}
104
/// Lifecycle state of an [`Experiment`].
///
/// Transitions enforced by `Experiment`:
/// - `Draft` -> `Running` via `start`
/// - `Running` -> `Paused` via `pause`, and `Paused` -> `Running` via `resume`
/// - `Running` or `Paused` -> `Concluded` via `conclude`
/// - any state other than `Concluded`/`Cancelled` -> `Cancelled` via `cancel`
///
/// Derives `Eq` and `Hash` (the enum is fieldless) so statuses can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExperimentStatus {
    /// Being configured; not started yet.
    Draft,
    /// Actively running.
    Running,
    /// Temporarily paused; can be resumed.
    Paused,
    /// Finished normally (terminal).
    Concluded,
    /// Stopped early (terminal).
    Cancelled,
}
119
120impl Experiment {
121    /// Create a new experiment
122    pub fn new(config: ExperimentConfig) -> Result<Self> {
123        // Validate configuration
124        if config.traffic_percentage <= 0.0 || config.traffic_percentage > 100.0 {
125            anyhow::bail!("Traffic percentage must be between 0 and 100");
126        }
127
128        if config.treatment_variants.is_empty() {
129            anyhow::bail!("At least one treatment variant is required");
130        }
131
132        if config.min_sample_size == 0 {
133            anyhow::bail!("Minimum sample size must be greater than 0");
134        }
135
136        Ok(Self {
137            id: Uuid::new_v4(),
138            config,
139            status: ExperimentStatus::Draft,
140            start_time: None,
141            end_time: None,
142            metadata: ExperimentMetadata::default(),
143        })
144    }
145
146    /// Get experiment ID
147    pub fn id(&self) -> &Uuid {
148        &self.id
149    }
150
151    /// Get experiment configuration
152    pub fn config(&self) -> &ExperimentConfig {
153        &self.config
154    }
155
156    /// Get current status
157    pub fn status(&self) -> ExperimentStatus {
158        self.status.clone()
159    }
160
161    /// Start the experiment
162    pub fn start(&mut self) -> Result<()> {
163        if self.status != ExperimentStatus::Draft {
164            anyhow::bail!("Can only start experiments in Draft status");
165        }
166
167        self.status = ExperimentStatus::Running;
168        self.start_time = Some(Utc::now());
169        self.metadata.last_updated = Some(Utc::now());
170        Ok(())
171    }
172
173    /// Pause the experiment
174    pub fn pause(&mut self) -> Result<()> {
175        if self.status != ExperimentStatus::Running {
176            anyhow::bail!("Can only pause running experiments");
177        }
178
179        self.status = ExperimentStatus::Paused;
180        self.metadata.last_updated = Some(Utc::now());
181        Ok(())
182    }
183
184    /// Resume the experiment
185    pub fn resume(&mut self) -> Result<()> {
186        if self.status != ExperimentStatus::Paused {
187            anyhow::bail!("Can only resume paused experiments");
188        }
189
190        self.status = ExperimentStatus::Running;
191        self.metadata.last_updated = Some(Utc::now());
192        Ok(())
193    }
194
195    /// Conclude the experiment
196    pub fn conclude(&mut self) -> Result<()> {
197        if self.status != ExperimentStatus::Running && self.status != ExperimentStatus::Paused {
198            anyhow::bail!("Can only conclude running or paused experiments");
199        }
200
201        self.status = ExperimentStatus::Concluded;
202        self.end_time = Some(Utc::now());
203        self.metadata.last_updated = Some(Utc::now());
204        Ok(())
205    }
206
207    /// Cancel the experiment
208    pub fn cancel(&mut self) -> Result<()> {
209        if self.status == ExperimentStatus::Concluded || self.status == ExperimentStatus::Cancelled
210        {
211            anyhow::bail!("Cannot cancel concluded or already cancelled experiments");
212        }
213
214        self.status = ExperimentStatus::Cancelled;
215        self.end_time = Some(Utc::now());
216        self.metadata.last_updated = Some(Utc::now());
217        Ok(())
218    }
219
220    /// Check if experiment should auto-conclude
221    pub fn should_auto_conclude(&self) -> bool {
222        if self.status != ExperimentStatus::Running {
223            return false;
224        }
225
226        // Check duration
227        if let Some(start_time) = self.start_time {
228            let elapsed = Utc::now() - start_time;
229            if elapsed > Duration::hours(self.config.max_duration_hours as i64) {
230                return true;
231            }
232        }
233
234        // Check sample sizes
235        let min_count = self.metadata.request_counts.values().min().copied().unwrap_or(0);
236        min_count >= self.config.min_sample_size
237    }
238
239    /// Get all variants (control + treatments)
240    pub fn all_variants(&self) -> Vec<&Variant> {
241        let mut variants = vec![&self.config.control_variant];
242        variants.extend(self.config.treatment_variants.iter());
243        variants
244    }
245
246    /// Update request count for a variant
247    pub fn increment_request_count(&mut self, variant_name: &str) {
248        *self.metadata.request_counts.entry(variant_name.to_string()).or_insert(0) += 1;
249        self.metadata.last_updated = Some(Utc::now());
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with one control ("control"/"model-v1") and one treatment
    /// ("treatment"/"model-v2") at 50% traffic and a 24-hour limit.
    fn two_arm_config(name: &str, description: &str, min_sample_size: usize) -> ExperimentConfig {
        ExperimentConfig {
            name: name.to_string(),
            description: description.to_string(),
            control_variant: Variant::new("control", "model-v1"),
            treatment_variants: vec![Variant::new("treatment", "model-v2")],
            traffic_percentage: 50.0,
            min_sample_size,
            max_duration_hours: 24,
        }
    }

    #[test]
    fn test_experiment_lifecycle() {
        let config = two_arm_config("Test Experiment", "Testing lifecycle", 100);
        let mut exp = Experiment::new(config).expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Draft);

        // Draft -> Running records the start time.
        exp.start().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Running);
        assert!(exp.start_time.is_some());

        // Running -> Paused -> Running.
        exp.pause().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Paused);
        exp.resume().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Running);

        // Running -> Concluded records the end time.
        exp.conclude().expect("operation failed in test");
        assert_eq!(exp.status(), ExperimentStatus::Concluded);
        assert!(exp.end_time.is_some());
    }

    #[test]
    fn test_variant_creation() {
        let plain = Variant::new("test", "model-123");
        assert_eq!(plain.name(), "test");
        assert_eq!(plain.model_id(), "model-123");
        assert!(plain.config_overrides().is_none());

        let overrides = serde_json::json!({
            "batch_size": 32,
            "temperature": 0.7
        });
        let tuned = Variant::with_config("test2", "model-456", overrides.clone());
        assert_eq!(tuned.config_overrides(), Some(&overrides));
    }

    #[test]
    fn test_auto_conclude() {
        let config = two_arm_config("Auto Conclude Test", "Testing auto conclusion", 2);
        let mut exp = Experiment::new(config).expect("operation failed in test");
        exp.start().expect("operation failed in test");

        // Nothing recorded yet, so the sample-size condition cannot be met.
        assert!(!exp.should_auto_conclude());

        for _ in 0..2 {
            exp.increment_request_count("control");
        }
        for _ in 0..2 {
            exp.increment_request_count("treatment");
        }

        // Both variants have reached the minimum of two samples.
        assert!(exp.should_auto_conclude());
    }
}