// trustformers_core/ab_testing/experiment.rs
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ExperimentConfig {
11 pub name: String,
13 pub description: String,
15 pub control_variant: Variant,
17 pub treatment_variants: Vec<Variant>,
19 pub traffic_percentage: f64,
21 pub min_sample_size: usize,
23 pub max_duration_hours: u64,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
29pub struct Variant {
30 name: String,
32 model_id: String,
34 config_overrides: Option<serde_json::Value>,
36}
37
38impl Variant {
39 pub fn new(name: &str, model_id: &str) -> Self {
41 Self {
42 name: name.to_string(),
43 model_id: model_id.to_string(),
44 config_overrides: None,
45 }
46 }
47
48 pub fn with_config(name: &str, model_id: &str, config: serde_json::Value) -> Self {
50 Self {
51 name: name.to_string(),
52 model_id: model_id.to_string(),
53 config_overrides: Some(config),
54 }
55 }
56
57 pub fn name(&self) -> &str {
59 &self.name
60 }
61
62 pub fn model_id(&self) -> &str {
64 &self.model_id
65 }
66
67 pub fn config_overrides(&self) -> Option<&serde_json::Value> {
69 self.config_overrides.as_ref()
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct Experiment {
76 id: Uuid,
78 config: ExperimentConfig,
80 status: ExperimentStatus,
82 start_time: Option<DateTime<Utc>>,
84 end_time: Option<DateTime<Utc>>,
86 metadata: ExperimentMetadata,
88}
89
90#[derive(Debug, Clone, Default)]
92pub struct ExperimentMetadata {
93 pub request_counts: std::collections::HashMap<String, usize>,
95 pub last_updated: Option<DateTime<Utc>>,
97 #[allow(dead_code)]
99 pub tags: Vec<String>,
100 #[allow(dead_code)]
102 pub owner: Option<String>,
103}
104
/// Lifecycle state of an experiment.
///
/// Equality between states is total, so the enum derives `Eq` in addition to
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExperimentStatus {
    /// Created but not yet started.
    Draft,
    /// Started and not currently paused.
    Running,
    /// Temporarily halted; can be resumed, concluded or cancelled.
    Paused,
    /// Finished normally; no further transitions are possible.
    Concluded,
    /// Aborted before concluding; no further transitions are possible.
    Cancelled,
}
119
120impl Experiment {
121 pub fn new(config: ExperimentConfig) -> Result<Self> {
123 if config.traffic_percentage <= 0.0 || config.traffic_percentage > 100.0 {
125 anyhow::bail!("Traffic percentage must be between 0 and 100");
126 }
127
128 if config.treatment_variants.is_empty() {
129 anyhow::bail!("At least one treatment variant is required");
130 }
131
132 if config.min_sample_size == 0 {
133 anyhow::bail!("Minimum sample size must be greater than 0");
134 }
135
136 Ok(Self {
137 id: Uuid::new_v4(),
138 config,
139 status: ExperimentStatus::Draft,
140 start_time: None,
141 end_time: None,
142 metadata: ExperimentMetadata::default(),
143 })
144 }
145
146 pub fn id(&self) -> &Uuid {
148 &self.id
149 }
150
151 pub fn config(&self) -> &ExperimentConfig {
153 &self.config
154 }
155
156 pub fn status(&self) -> ExperimentStatus {
158 self.status.clone()
159 }
160
161 pub fn start(&mut self) -> Result<()> {
163 if self.status != ExperimentStatus::Draft {
164 anyhow::bail!("Can only start experiments in Draft status");
165 }
166
167 self.status = ExperimentStatus::Running;
168 self.start_time = Some(Utc::now());
169 self.metadata.last_updated = Some(Utc::now());
170 Ok(())
171 }
172
173 pub fn pause(&mut self) -> Result<()> {
175 if self.status != ExperimentStatus::Running {
176 anyhow::bail!("Can only pause running experiments");
177 }
178
179 self.status = ExperimentStatus::Paused;
180 self.metadata.last_updated = Some(Utc::now());
181 Ok(())
182 }
183
184 pub fn resume(&mut self) -> Result<()> {
186 if self.status != ExperimentStatus::Paused {
187 anyhow::bail!("Can only resume paused experiments");
188 }
189
190 self.status = ExperimentStatus::Running;
191 self.metadata.last_updated = Some(Utc::now());
192 Ok(())
193 }
194
195 pub fn conclude(&mut self) -> Result<()> {
197 if self.status != ExperimentStatus::Running && self.status != ExperimentStatus::Paused {
198 anyhow::bail!("Can only conclude running or paused experiments");
199 }
200
201 self.status = ExperimentStatus::Concluded;
202 self.end_time = Some(Utc::now());
203 self.metadata.last_updated = Some(Utc::now());
204 Ok(())
205 }
206
207 pub fn cancel(&mut self) -> Result<()> {
209 if self.status == ExperimentStatus::Concluded || self.status == ExperimentStatus::Cancelled
210 {
211 anyhow::bail!("Cannot cancel concluded or already cancelled experiments");
212 }
213
214 self.status = ExperimentStatus::Cancelled;
215 self.end_time = Some(Utc::now());
216 self.metadata.last_updated = Some(Utc::now());
217 Ok(())
218 }
219
220 pub fn should_auto_conclude(&self) -> bool {
222 if self.status != ExperimentStatus::Running {
223 return false;
224 }
225
226 if let Some(start_time) = self.start_time {
228 let elapsed = Utc::now() - start_time;
229 if elapsed > Duration::hours(self.config.max_duration_hours as i64) {
230 return true;
231 }
232 }
233
234 let min_count = self.metadata.request_counts.values().min().copied().unwrap_or(0);
236 min_count >= self.config.min_sample_size
237 }
238
239 pub fn all_variants(&self) -> Vec<&Variant> {
241 let mut variants = vec![&self.config.control_variant];
242 variants.extend(self.config.treatment_variants.iter());
243 variants
244 }
245
246 pub fn increment_request_count(&mut self, variant_name: &str) {
248 *self.metadata.request_counts.entry(variant_name.to_string()).or_insert(0) += 1;
249 self.metadata.last_updated = Some(Utc::now());
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with one control ("control" / "model-v1") and one
    /// treatment ("treatment" / "model-v2"), 50% traffic and a 24 hour limit.
    fn two_arm_config(name: &str, description: &str, min_sample_size: usize) -> ExperimentConfig {
        ExperimentConfig {
            name: name.to_string(),
            description: description.to_string(),
            control_variant: Variant::new("control", "model-v1"),
            treatment_variants: vec![Variant::new("treatment", "model-v2")],
            traffic_percentage: 50.0,
            min_sample_size,
            max_duration_hours: 24,
        }
    }

    /// Walks an experiment through Draft -> Running -> Paused -> Running ->
    /// Concluded and checks the status and timestamps along the way.
    #[test]
    fn test_experiment_lifecycle() {
        let config = two_arm_config("Test Experiment", "Testing lifecycle", 100);
        let mut experiment = Experiment::new(config).expect("operation failed in test");
        assert_eq!(experiment.status(), ExperimentStatus::Draft);

        experiment.start().expect("operation failed in test");
        assert_eq!(experiment.status(), ExperimentStatus::Running);
        assert!(experiment.start_time.is_some());

        experiment.pause().expect("operation failed in test");
        assert_eq!(experiment.status(), ExperimentStatus::Paused);

        experiment.resume().expect("operation failed in test");
        assert_eq!(experiment.status(), ExperimentStatus::Running);

        experiment.conclude().expect("operation failed in test");
        assert_eq!(experiment.status(), ExperimentStatus::Concluded);
        assert!(experiment.end_time.is_some());
    }

    /// Checks both `Variant` constructors and their accessors.
    #[test]
    fn test_variant_creation() {
        let plain = Variant::new("test", "model-123");
        assert_eq!(plain.name(), "test");
        assert_eq!(plain.model_id(), "model-123");
        assert!(plain.config_overrides().is_none());

        let overrides = serde_json::json!({
            "batch_size": 32,
            "temperature": 0.7
        });
        let configured = Variant::with_config("test2", "model-456", overrides.clone());
        assert_eq!(configured.config_overrides(), Some(&overrides));
    }

    /// A running experiment is not ready to conclude until every variant has
    /// reached the minimum sample size.
    #[test]
    fn test_auto_conclude() {
        let config = two_arm_config("Auto Conclude Test", "Testing auto conclusion", 2);
        let mut experiment = Experiment::new(config).expect("operation failed in test");
        experiment.start().expect("operation failed in test");

        assert!(!experiment.should_auto_conclude());

        // Two requests per variant: control first, then treatment.
        for variant_name in ["control", "treatment"].iter() {
            experiment.increment_request_count(variant_name);
            experiment.increment_request_count(variant_name);
        }

        assert!(experiment.should_auto_conclude());
    }
}