statsig_client/
batch.rs

1//! Batch processing module for optimizing API requests
2//!
3//! This module handles batching multiple gate and config requests into single API calls
4//! to reduce network overhead and improve performance.
5
6use crate::{
7    api::{ConfigEvaluationResult, GateEvaluationResult},
8    config::StatsigClientConfig,
9    error::Result,
10    transport::StatsigTransport,
11    user::User,
12};
13use std::collections::HashMap;
14use tokio::sync::{mpsc, oneshot};
15use tracing::{error, info};
16
/// Represents different types of batch requests
///
/// Each variant carries a `oneshot::Sender` through which the batched
/// evaluation results (or an error) are delivered back to the requester.
/// Dropping the sender without sending signals failure to the caller.
#[derive(Debug)]
pub enum BatchRequest {
    /// Evaluate one or more feature gates for a single user.
    CheckGates {
        // Names of the gates to evaluate.
        gate_names: Vec<String>,
        // The user the gates are evaluated against.
        user: User,
        // Reply channel: receives the evaluation results for `gate_names`.
        response_tx: oneshot::Sender<Result<Vec<GateEvaluationResult>>>,
    },
    /// Fetch one or more dynamic configs for a single user.
    GetConfigs {
        // Names of the configs to fetch.
        config_names: Vec<String>,
        // The user the configs are evaluated against.
        user: User,
        // Reply channel: receives the config results for `config_names`.
        response_tx: oneshot::Sender<Result<Vec<ConfigEvaluationResult>>>,
    },
}
31
/// Handles batch processing of API requests
///
/// Owns the receiving half of the request channel plus a broadcast
/// receiver used to observe client-wide shutdown; see [`BatchProcessor::run`].
pub struct BatchProcessor {
    // Incoming batchable requests from client handles.
    receiver: mpsc::Receiver<BatchRequest>,
    // Signals that the client is shutting down and the loop should exit.
    shutdown_rx: tokio::sync::broadcast::Receiver<()>,
}
37
38impl BatchProcessor {
39    /// Creates a new batch processor
40    pub fn new(
41        receiver: mpsc::Receiver<BatchRequest>,
42        shutdown_rx: tokio::sync::broadcast::Receiver<()>,
43    ) -> Self {
44        Self {
45            receiver,
46            shutdown_rx,
47        }
48    }
49
50    /// Runs the batch processor loop
51    pub async fn run(mut self, transport: StatsigTransport, config: StatsigClientConfig) {
52        let mut interval = tokio::time::interval(config.batch_flush_interval);
53        let mut gate_requests = Vec::new();
54        let mut config_requests = Vec::new();
55
56        loop {
57            tokio::select! {
58                Some(request) = self.receiver.recv() => {
59                    match request {
60                        BatchRequest::CheckGates { .. } => gate_requests.push(request),
61                        BatchRequest::GetConfigs { .. } => config_requests.push(request),
62                    }
63
64                    // Process if batch size reached
65                    if gate_requests.len() >= config.batch_size || config_requests.len() >= config.batch_size {
66                        Self::process_gate_batch(&transport, &mut gate_requests).await;
67                        Self::process_config_batch(&transport, &mut config_requests).await;
68                    }
69                }
70                _ = interval.tick() => {
71                    if !gate_requests.is_empty() {
72                        Self::process_gate_batch(&transport, &mut gate_requests).await;
73                    }
74                    if !config_requests.is_empty() {
75                        Self::process_config_batch(&transport, &mut config_requests).await;
76                    }
77                }
78                _ = self.shutdown_rx.recv() => {
79                    info!("Batch processor shutting down");
80                    break;
81                }
82            }
83        }
84    }
85
86    /// Processes a batch of gate requests
87    async fn process_gate_batch(transport: &StatsigTransport, requests: &mut Vec<BatchRequest>) {
88        if requests.is_empty() {
89            return;
90        }
91
92        let batch = std::mem::take(requests);
93
94        // Group by user for efficiency
95        let mut user_groups: HashMap<String, Vec<_>> = HashMap::new();
96        for request in batch {
97            if let BatchRequest::CheckGates { user, .. } = &request {
98                let user_hash = Self::hash_user_for_batch(user);
99                user_groups.entry(user_hash).or_default().push(request);
100            }
101        }
102
103        for (_user_hash, group_requests) in user_groups {
104            if let Some(first_request) = group_requests.first() {
105                if let BatchRequest::CheckGates { user, .. } = first_request {
106                    let all_gate_names: Vec<String> = group_requests
107                        .iter()
108                        .filter_map(|req| {
109                            if let BatchRequest::CheckGates { gate_names, .. } = req {
110                                Some(gate_names.clone())
111                            } else {
112                                None
113                            }
114                        })
115                        .flatten()
116                        .collect();
117
118                    match transport.check_gates(all_gate_names, user).await {
119                        Ok(results) => {
120                            // Distribute results back to requesters
121                            for request in group_requests {
122                                if let BatchRequest::CheckGates {
123                                    gate_names,
124                                    response_tx,
125                                    ..
126                                } = request
127                                {
128                                    let filtered_results: Vec<GateEvaluationResult> = results
129                                        .iter()
130                                        .filter(|result| gate_names.contains(&result.name))
131                                        .cloned()
132                                        .collect();
133                                    let _ = response_tx.send(Ok(filtered_results));
134                                }
135                            }
136                        }
137                        Err(e) => {
138                            error!("Failed to fetch gates from API: {:?}", e);
139                            // Send error to all requesters
140                            for request in group_requests {
141                                if let BatchRequest::CheckGates { response_tx, .. } = request {
142                                    let _ = response_tx.send(Err(e.clone()));
143                                }
144                            }
145                        }
146                    }
147                }
148            }
149        }
150    }
151
152    /// Processes a batch of config requests
153    async fn process_config_batch(transport: &StatsigTransport, requests: &mut Vec<BatchRequest>) {
154        if requests.is_empty() {
155            return;
156        }
157
158        let batch = std::mem::take(requests);
159
160        // Process each config request individually for now (could be optimized)
161        for request in batch {
162            if let BatchRequest::GetConfigs {
163                config_names,
164                user,
165                response_tx,
166            } = request
167            {
168                let results = Self::fetch_configs_from_api(transport, &config_names, &user).await;
169                let _ = response_tx.send(results);
170            }
171        }
172    }
173
174    /// Fetches configs from the Statsig API
175    async fn fetch_configs_from_api(
176        transport: &StatsigTransport,
177        config_names: &[String],
178        user: &User,
179    ) -> Result<Vec<ConfigEvaluationResult>> {
180        let mut results = Vec::new();
181
182        for config_name in config_names {
183            let config_result = transport.get_config(config_name, user).await?;
184            results.push(config_result);
185        }
186
187        Ok(results)
188    }
189
190    /// Hashes user for batch grouping
191    fn hash_user_for_batch(user: &User) -> String {
192        user.hash_for_cache()
193    }
194}