Skip to main content

shodh_memory/handlers/
ab_testing.rs

1//! A/B Testing Handlers
2//!
3//! Handlers for A/B test management, metrics recording, and analysis.
4
5use axum::{
6    extract::{Path, State},
7    response::Json,
8};
9use serde::{Deserialize, Serialize};
10
11use super::state::MultiUserMemoryManager;
12use crate::ab_testing;
13use crate::errors::{AppError, ValidationErrorExt};
14use crate::relevance;
15use crate::validation;
16use std::sync::Arc;
17
/// Shared state handed to every handler in this module: a reference-counted
/// handle to the multi-user memory manager, whose `ab_test_manager` field is
/// what these A/B handlers operate on.
type AppState = Arc<MultiUserMemoryManager>;
19
/// Serde default for `CreateABTestRequest::traffic_split` when the field is
/// omitted: `0.5`.
fn default_traffic_split() -> f32 {
    const EVEN_SPLIT: f32 = 0.5;
    EVEN_SPLIT
}
23
/// Serde default for `CreateABTestRequest::min_impressions` when the field is
/// omitted: `100`.
fn default_min_impressions() -> u64 {
    const DEFAULT_MIN_IMPRESSIONS: u64 = 100;
    DEFAULT_MIN_IMPRESSIONS
}
27
/// Request to create a new A/B test
///
/// JSON body of `POST /api/ab/tests`. Only `name` is required; every other
/// field falls back to a serde default when omitted.
#[derive(Debug, Deserialize)]
pub struct CreateABTestRequest {
    /// Name of the test, passed to `ab_testing::ABTest::builder`.
    pub name: String,
    /// Optional free-text description.
    #[serde(default)]
    pub description: Option<String>,
    /// Ranking weights for the control variant. `create_ab_test` only calls
    /// `with_control` when this is present; otherwise the builder's own
    /// default is kept.
    #[serde(default)]
    pub control_weights: Option<relevance::LearnedWeights>,
    /// Ranking weights for the treatment variant. `create_ab_test` only calls
    /// `with_treatment` when this is present.
    #[serde(default)]
    pub treatment_weights: Option<relevance::LearnedWeights>,
    /// Defaults to 0.5 (see `default_traffic_split`).
    /// NOTE(review): presumably the fraction of users routed to the treatment
    /// variant — confirm against `ab_testing`.
    #[serde(default = "default_traffic_split")]
    pub traffic_split: f32,
    /// Defaults to 100 (see `default_min_impressions`).
    /// NOTE(review): presumably the impressions needed before results are
    /// considered meaningful — confirm against `ab_testing`.
    #[serde(default = "default_min_impressions")]
    pub min_impressions: u64,
    /// Optional duration limit in hours, forwarded to `with_max_duration_hours`.
    #[serde(default)]
    pub max_duration_hours: Option<u64>,
    /// Optional labels. An empty list (the default) skips `with_tags` entirely.
    #[serde(default)]
    pub tags: Vec<String>,
}
47
/// Response for A/B test operations
///
/// Returned by the create, delete, start, pause and resume handlers. Those
/// handlers report failures in-band (`success: false`) rather than through an
/// HTTP error status.
#[derive(Debug, Serialize)]
pub struct ABTestResponse {
    /// Whether the requested operation succeeded.
    pub success: bool,
    /// ID of the affected test on success; the handlers set this to `None`
    /// on failure.
    pub test_id: Option<String>,
    /// Human-readable outcome. On failure it embeds the manager's error text.
    pub message: String,
}
55
/// Request to record an impression
///
/// JSON body of `POST /api/ab/tests/{test_id}/impression`.
#[derive(Debug, Deserialize)]
pub struct RecordImpressionRequest {
    /// User the impression belongs to. The handler checks it with
    /// `validation::validate_user_id` before recording.
    pub user_id: String,
    /// Optional relevance score. The handler substitutes `0.0` when absent.
    #[serde(default)]
    pub relevance_score: Option<f64>,
    /// Optional latency. The handler substitutes `0` when absent.
    /// NOTE(review): the `_us` suffix suggests microseconds — confirm.
    #[serde(default)]
    pub latency_us: Option<u64>,
}
65
/// Request to record a click
///
/// JSON body of `POST /api/ab/tests/{test_id}/click`.
#[derive(Debug, Deserialize)]
pub struct RecordClickRequest {
    /// User who clicked. The handler checks it with
    /// `validation::validate_user_id` before recording.
    pub user_id: String,
    /// Identifier of the memory involved in the click; must deserialize as a
    /// UUID. NOTE(review): presumably the clicked search result — confirm.
    pub memory_id: uuid::Uuid,
}
72
/// Request to record feedback
///
/// JSON body of `POST /api/ab/tests/{test_id}/feedback`.
#[derive(Debug, Deserialize)]
pub struct RecordFeedbackRequest {
    /// User giving the feedback. The handler checks it with
    /// `validation::validate_user_id` before recording.
    pub user_id: String,
    /// Feedback polarity forwarded to `record_feedback`. NOTE(review):
    /// presumably `true` counts toward positive and `false` toward negative
    /// feedback metrics — confirm in `ab_testing`.
    pub positive: bool,
}
79
80/// GET /api/ab/tests - List all A/B tests
81pub async fn list_ab_tests(
82    State(state): State<AppState>,
83) -> Result<Json<serde_json::Value>, AppError> {
84    let tests = state.ab_test_manager.list_tests();
85    let summary = state.ab_test_manager.summary();
86
87    Ok(Json(serde_json::json!({
88        "success": true,
89        "tests": tests.iter().map(|t| serde_json::json!({
90            "id": t.id,
91            "name": t.config.name,
92            "description": t.config.description,
93            "status": format!("{:?}", t.status),
94            "traffic_split": t.config.traffic_split,
95            "control_impressions": t.control_metrics.impressions,
96            "treatment_impressions": t.treatment_metrics.impressions,
97            "created_at": t.created_at.to_rfc3339(),
98        })).collect::<Vec<_>>(),
99        "summary": {
100            "total_active": summary.total_active,
101            "draft": summary.draft,
102            "running": summary.running,
103            "paused": summary.paused,
104            "completed": summary.completed,
105            "archived": summary.archived,
106        }
107    })))
108}
109
110/// POST /api/ab/tests - Create a new A/B test
111pub async fn create_ab_test(
112    State(state): State<AppState>,
113    Json(req): Json<CreateABTestRequest>,
114) -> Result<Json<ABTestResponse>, AppError> {
115    let mut builder = ab_testing::ABTest::builder(&req.name)
116        .with_traffic_split(req.traffic_split)
117        .with_min_impressions(req.min_impressions);
118
119    if let Some(desc) = req.description {
120        builder = builder.with_description(&desc);
121    }
122
123    if let Some(control) = req.control_weights {
124        builder = builder.with_control(control);
125    }
126
127    if let Some(treatment) = req.treatment_weights {
128        builder = builder.with_treatment(treatment);
129    }
130
131    if let Some(hours) = req.max_duration_hours {
132        builder = builder.with_max_duration_hours(hours);
133    }
134
135    if !req.tags.is_empty() {
136        builder = builder.with_tags(req.tags);
137    }
138
139    let test = builder.build();
140
141    match state.ab_test_manager.create_test(test) {
142        Ok(test_id) => Ok(Json(ABTestResponse {
143            success: true,
144            test_id: Some(test_id),
145            message: "A/B test created successfully".to_string(),
146        })),
147        Err(e) => Ok(Json(ABTestResponse {
148            success: false,
149            test_id: None,
150            message: format!("Failed to create test: {}", e),
151        })),
152    }
153}
154
155/// GET /api/ab/tests/{test_id} - Get a specific A/B test
156pub async fn get_ab_test(
157    State(state): State<AppState>,
158    Path(test_id): Path<String>,
159) -> Result<Json<serde_json::Value>, AppError> {
160    match state.ab_test_manager.get_test(&test_id) {
161        Some(test) => Ok(Json(serde_json::json!({
162            "success": true,
163            "test": {
164                "id": test.id,
165                "name": test.config.name,
166                "description": test.config.description,
167                "status": format!("{:?}", test.status),
168                "traffic_split": test.config.traffic_split,
169                "min_impressions": test.config.min_impressions,
170                "max_duration_hours": test.config.max_duration_hours,
171                "control_weights": test.config.control_weights,
172                "treatment_weights": test.config.treatment_weights,
173                "control_metrics": {
174                    "impressions": test.control_metrics.impressions,
175                    "clicks": test.control_metrics.clicks,
176                    "ctr": if test.control_metrics.impressions > 0 {
177                        test.control_metrics.clicks as f64 / test.control_metrics.impressions as f64
178                    } else { 0.0 },
179                    "positive_feedback": test.control_metrics.positive_feedback,
180                    "negative_feedback": test.control_metrics.negative_feedback,
181                },
182                "treatment_metrics": {
183                    "impressions": test.treatment_metrics.impressions,
184                    "clicks": test.treatment_metrics.clicks,
185                    "ctr": if test.treatment_metrics.impressions > 0 {
186                        test.treatment_metrics.clicks as f64 / test.treatment_metrics.impressions as f64
187                    } else { 0.0 },
188                    "positive_feedback": test.treatment_metrics.positive_feedback,
189                    "negative_feedback": test.treatment_metrics.negative_feedback,
190                },
191                "created_at": test.created_at.to_rfc3339(),
192                "started_at": test.started_at.map(|t| t.to_rfc3339()),
193                "completed_at": test.completed_at.map(|t| t.to_rfc3339()),
194            }
195        }))),
196        None => Ok(Json(serde_json::json!({
197            "success": false,
198            "message": format!("Test not found: {}", test_id)
199        }))),
200    }
201}
202
203/// DELETE /api/ab/tests/{test_id} - Delete an A/B test
204pub async fn delete_ab_test(
205    State(state): State<AppState>,
206    Path(test_id): Path<String>,
207) -> Result<Json<ABTestResponse>, AppError> {
208    match state.ab_test_manager.delete_test(&test_id) {
209        Ok(()) => Ok(Json(ABTestResponse {
210            success: true,
211            test_id: Some(test_id),
212            message: "Test deleted successfully".to_string(),
213        })),
214        Err(e) => Ok(Json(ABTestResponse {
215            success: false,
216            test_id: None,
217            message: format!("Failed to delete test: {}", e),
218        })),
219    }
220}
221
222/// POST /api/ab/tests/{test_id}/start - Start an A/B test
223pub async fn start_ab_test(
224    State(state): State<AppState>,
225    Path(test_id): Path<String>,
226) -> Result<Json<ABTestResponse>, AppError> {
227    match state.ab_test_manager.start_test(&test_id) {
228        Ok(()) => Ok(Json(ABTestResponse {
229            success: true,
230            test_id: Some(test_id),
231            message: "Test started successfully".to_string(),
232        })),
233        Err(e) => Ok(Json(ABTestResponse {
234            success: false,
235            test_id: None,
236            message: format!("Failed to start test: {}", e),
237        })),
238    }
239}
240
241/// POST /api/ab/tests/{test_id}/pause - Pause an A/B test
242pub async fn pause_ab_test(
243    State(state): State<AppState>,
244    Path(test_id): Path<String>,
245) -> Result<Json<ABTestResponse>, AppError> {
246    match state.ab_test_manager.pause_test(&test_id) {
247        Ok(()) => Ok(Json(ABTestResponse {
248            success: true,
249            test_id: Some(test_id),
250            message: "Test paused successfully".to_string(),
251        })),
252        Err(e) => Ok(Json(ABTestResponse {
253            success: false,
254            test_id: None,
255            message: format!("Failed to pause test: {}", e),
256        })),
257    }
258}
259
260/// POST /api/ab/tests/{test_id}/resume - Resume a paused A/B test
261pub async fn resume_ab_test(
262    State(state): State<AppState>,
263    Path(test_id): Path<String>,
264) -> Result<Json<ABTestResponse>, AppError> {
265    match state.ab_test_manager.resume_test(&test_id) {
266        Ok(()) => Ok(Json(ABTestResponse {
267            success: true,
268            test_id: Some(test_id),
269            message: "Test resumed successfully".to_string(),
270        })),
271        Err(e) => Ok(Json(ABTestResponse {
272            success: false,
273            test_id: None,
274            message: format!("Failed to resume test: {}", e),
275        })),
276    }
277}
278
279/// POST /api/ab/tests/{test_id}/complete - Complete an A/B test and get results
280pub async fn complete_ab_test(
281    State(state): State<AppState>,
282    Path(test_id): Path<String>,
283) -> Result<Json<serde_json::Value>, AppError> {
284    match state.ab_test_manager.complete_test(&test_id) {
285        Ok(results) => Ok(Json(serde_json::json!({
286            "success": true,
287            "test_id": test_id,
288            "results": {
289                "is_significant": results.is_significant,
290                "confidence_level": results.confidence_level,
291                "chi_squared": results.chi_squared,
292                "p_value": results.p_value,
293                "winner": results.winner.map(|w| format!("{:?}", w)),
294                "relative_improvement": results.relative_improvement,
295                "control_ctr": results.control_ctr,
296                "treatment_ctr": results.treatment_ctr,
297                "confidence_interval": results.confidence_interval,
298                "recommendations": results.recommendations,
299            }
300        }))),
301        Err(e) => Ok(Json(serde_json::json!({
302            "success": false,
303            "message": format!("Failed to complete test: {}", e)
304        }))),
305    }
306}
307
308/// GET /api/ab/tests/{test_id}/analyze - Analyze an A/B test without completing it
309pub async fn analyze_ab_test(
310    State(state): State<AppState>,
311    Path(test_id): Path<String>,
312) -> Result<Json<serde_json::Value>, AppError> {
313    match state.ab_test_manager.analyze_test(&test_id) {
314        Ok(results) => Ok(Json(serde_json::json!({
315            "success": true,
316            "test_id": test_id,
317            "analysis": {
318                "is_significant": results.is_significant,
319                "confidence_level": results.confidence_level,
320                "chi_squared": results.chi_squared,
321                "p_value": results.p_value,
322                "winner": results.winner.map(|w| format!("{:?}", w)),
323                "relative_improvement": results.relative_improvement,
324                "control_ctr": results.control_ctr,
325                "treatment_ctr": results.treatment_ctr,
326                "confidence_interval": results.confidence_interval,
327                "recommendations": results.recommendations,
328            }
329        }))),
330        Err(e) => Ok(Json(serde_json::json!({
331            "success": false,
332            "message": format!("Failed to analyze test: {}", e)
333        }))),
334    }
335}
336
337/// POST /api/ab/tests/{test_id}/impression - Record an impression for an A/B test
338pub async fn record_ab_impression(
339    State(state): State<AppState>,
340    Path(test_id): Path<String>,
341    Json(req): Json<RecordImpressionRequest>,
342) -> Result<Json<serde_json::Value>, AppError> {
343    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
344
345    let relevance_score = req.relevance_score.unwrap_or(0.0);
346    let latency_us = req.latency_us.unwrap_or(0);
347
348    match state.ab_test_manager.record_impression(
349        &test_id,
350        &req.user_id,
351        relevance_score,
352        latency_us,
353    ) {
354        Ok(()) => {
355            let variant = state
356                .ab_test_manager
357                .get_variant(&test_id, &req.user_id)
358                .ok();
359            Ok(Json(serde_json::json!({
360                "success": true,
361                "variant": variant.map(|v| format!("{:?}", v)),
362            })))
363        }
364        Err(e) => Ok(Json(serde_json::json!({
365            "success": false,
366            "message": format!("Failed to record impression: {}", e)
367        }))),
368    }
369}
370
371/// POST /api/ab/tests/{test_id}/click - Record a click for an A/B test
372pub async fn record_ab_click(
373    State(state): State<AppState>,
374    Path(test_id): Path<String>,
375    Json(req): Json<RecordClickRequest>,
376) -> Result<Json<serde_json::Value>, AppError> {
377    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
378
379    match state
380        .ab_test_manager
381        .record_click(&test_id, &req.user_id, req.memory_id)
382    {
383        Ok(()) => Ok(Json(serde_json::json!({
384            "success": true,
385        }))),
386        Err(e) => Ok(Json(serde_json::json!({
387            "success": false,
388            "message": format!("Failed to record click: {}", e)
389        }))),
390    }
391}
392
393/// POST /api/ab/tests/{test_id}/feedback - Record feedback for an A/B test
394pub async fn record_ab_feedback(
395    State(state): State<AppState>,
396    Path(test_id): Path<String>,
397    Json(req): Json<RecordFeedbackRequest>,
398) -> Result<Json<serde_json::Value>, AppError> {
399    validation::validate_user_id(&req.user_id).map_validation_err("user_id")?;
400
401    match state
402        .ab_test_manager
403        .record_feedback(&test_id, &req.user_id, req.positive)
404    {
405        Ok(()) => Ok(Json(serde_json::json!({
406            "success": true,
407        }))),
408        Err(e) => Ok(Json(serde_json::json!({
409            "success": false,
410            "message": format!("Failed to record feedback: {}", e)
411        }))),
412    }
413}
414
415/// GET /api/ab/summary - Get summary of all A/B tests
416pub async fn get_ab_summary(
417    State(state): State<AppState>,
418) -> Result<Json<serde_json::Value>, AppError> {
419    let summary = state.ab_test_manager.summary();
420    let expired = state.ab_test_manager.check_expired_tests();
421
422    Ok(Json(serde_json::json!({
423        "success": true,
424        "summary": {
425            "total_active": summary.total_active,
426            "draft": summary.draft,
427            "running": summary.running,
428            "paused": summary.paused,
429            "completed": summary.completed,
430            "archived": summary.archived,
431        },
432        "expired_tests": expired,
433    })))
434}