// trustformers_wasm/optimization/batch_processing.rs
1//! Batch processing support for efficient inference
2
3use crate::core::tensor::WasmTensor;
4use js_sys::{Date, Promise};
5use serde::{Deserialize, Serialize};
6use std::string::String;
7use std::vec::Vec;
8use wasm_bindgen::prelude::*;
9use wasm_bindgen_futures::JsFuture;
10
/// Batching strategies for different use cases.
///
/// The strategy controls both when a batch fires (`BatchProcessor::is_batch_ready`)
/// and how many queued requests it contains (`determine_batch_size`).
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchingStrategy {
    /// Process immediately without batching (batch size is always 1)
    Immediate,
    /// Fixed-size batches: wait until `max_batch_size` requests are queued
    FixedSize,
    /// Dynamic batching: fire on `timeout_ms` elapsed or a full batch
    Dynamic,
    /// Adaptive batching: like Dynamic, but the batch-size target is tuned
    /// at runtime against `target_latency_ms`
    Adaptive,
}
24
/// Batch processing configuration.
///
/// Construct via `new` or one of the presets (`real_time`, `throughput`,
/// `mobile`); tune individual knobs with the setters.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Which batching policy the processor follows.
    strategy: BatchingStrategy,
    /// Hard upper bound on requests per batch.
    max_batch_size: usize,
    /// Max time (ms) a partial batch may wait before firing (Dynamic/Adaptive).
    timeout_ms: u32,
    /// Latency goal (ms) the Adaptive strategy steers the batch size toward.
    target_latency_ms: u32,
    /// Memory budget in MB. NOTE(review): stored but not enforced anywhere
    /// in this file — confirm intended use.
    memory_limit_mb: f32,
    /// When true, `add_request` inserts by priority instead of FIFO.
    enable_prioritization: bool,
    /// When true, high-priority requests may preempt a batch.
    /// NOTE(review): flag is never read in this file — confirm intended use.
    enable_preemption: bool,
}
37
38#[wasm_bindgen]
39impl BatchConfig {
40    /// Create a new batch configuration
41    #[wasm_bindgen(constructor)]
42    pub fn new(strategy: BatchingStrategy, max_batch_size: usize) -> Self {
43        Self {
44            strategy,
45            max_batch_size,
46            timeout_ms: 100,       // 100ms default timeout
47            target_latency_ms: 50, // Target 50ms latency
48            memory_limit_mb: 100.0,
49            enable_prioritization: false,
50            enable_preemption: false,
51        }
52    }
53
54    /// Create configuration optimized for real-time applications
55    pub fn real_time() -> Self {
56        Self {
57            strategy: BatchingStrategy::Dynamic,
58            max_batch_size: 4,
59            timeout_ms: 10,
60            target_latency_ms: 20,
61            memory_limit_mb: 50.0,
62            enable_prioritization: true,
63            enable_preemption: true,
64        }
65    }
66
67    /// Create configuration optimized for throughput
68    pub fn throughput() -> Self {
69        Self {
70            strategy: BatchingStrategy::FixedSize,
71            max_batch_size: 32,
72            timeout_ms: 500,
73            target_latency_ms: 200,
74            memory_limit_mb: 500.0,
75            enable_prioritization: false,
76            enable_preemption: false,
77        }
78    }
79
80    /// Create configuration optimized for mobile devices
81    pub fn mobile() -> Self {
82        Self {
83            strategy: BatchingStrategy::Adaptive,
84            max_batch_size: 2,
85            timeout_ms: 50,
86            target_latency_ms: 100,
87            memory_limit_mb: 20.0,
88            enable_prioritization: true,
89            enable_preemption: false,
90        }
91    }
92
93    /// Set timeout for batch completion
94    pub fn set_timeout_ms(&mut self, timeout_ms: u32) {
95        self.timeout_ms = timeout_ms;
96    }
97
98    /// Set target latency
99    pub fn set_target_latency_ms(&mut self, latency_ms: u32) {
100        self.target_latency_ms = latency_ms;
101    }
102
103    /// Set memory limit
104    pub fn set_memory_limit_mb(&mut self, limit_mb: f32) {
105        self.memory_limit_mb = limit_mb;
106    }
107
108    /// Enable request prioritization
109    pub fn enable_prioritization(&mut self) {
110        self.enable_prioritization = true;
111    }
112
113    /// Enable batch preemption for high-priority requests
114    pub fn enable_preemption(&mut self) {
115        self.enable_preemption = true;
116    }
117}
118
/// Priority levels for batch requests.
///
/// Derives `Ord` from the declaration order, so
/// `Low < Normal < High < Critical`; `BatchProcessor::add_request` relies
/// on this ordering when prioritization is enabled.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
128
/// Batch request with metadata.
///
/// Internal (not `#[wasm_bindgen]`-exported) record kept in the
/// processor's queue between `add_request` and `process_batch`.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Unique id in the form `req_<counter>` (see `add_request`).
    pub id: String,
    /// Input tensor to batch with other pending requests.
    pub input: WasmTensor,
    /// Scheduling priority; controls queue insertion position.
    pub priority: Priority,
    /// `Date::now()` at enqueue time, used to compute queue time.
    pub timestamp: f64,
    /// Optional per-request timeout. NOTE(review): stored but never read
    /// in this file — confirm intended enforcement.
    pub timeout_ms: Option<u32>,
    /// Optional JS completion callback. Always `None` as enqueued by
    /// `add_request` in this file.
    pub callback: Option<js_sys::Function>,
}
139
/// Batch response with results for a single request.
#[wasm_bindgen]
pub struct BatchResponse {
    /// Id of the originating request.
    request_id: String,
    /// Output tensor; `None` on failure.
    result: Option<WasmTensor>,
    /// Error message; `None` on success.
    error: Option<String>,
    /// Wall-clock time (ms) the whole batch spent in inference.
    processing_time_ms: f64,
    /// Time (ms) the request waited in the queue before its batch started.
    queue_time_ms: f64,
    /// Number of requests in the batch this request was processed with.
    batch_size: usize,
}
150
151#[wasm_bindgen]
152impl BatchResponse {
153    /// Get the request ID
154    #[wasm_bindgen(getter)]
155    pub fn request_id(&self) -> String {
156        self.request_id.clone()
157    }
158
159    /// Get the result tensor
160    pub fn result(&self) -> Option<WasmTensor> {
161        self.result.clone()
162    }
163
164    /// Get error message if any
165    #[wasm_bindgen(getter)]
166    pub fn error(&self) -> Option<String> {
167        self.error.clone()
168    }
169
170    /// Get processing time in milliseconds
171    #[wasm_bindgen(getter)]
172    pub fn processing_time_ms(&self) -> f64 {
173        self.processing_time_ms
174    }
175
176    /// Get queue time in milliseconds
177    #[wasm_bindgen(getter)]
178    pub fn queue_time_ms(&self) -> f64 {
179        self.queue_time_ms
180    }
181
182    /// Get the batch size this request was processed in
183    #[wasm_bindgen(getter)]
184    pub fn batch_size(&self) -> usize {
185        self.batch_size
186    }
187
188    /// Check if the request was successful
189    #[wasm_bindgen(getter)]
190    pub fn is_success(&self) -> bool {
191        self.error.is_none() && self.result.is_some()
192    }
193}
194
/// Batch statistics for monitoring.
///
/// Counters and running averages maintained by
/// `BatchProcessor::update_stats` and rendered by `get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    /// Requests ever enqueued via `add_request`.
    pub total_requests: usize,
    /// Requests that completed successfully.
    pub completed_requests: usize,
    /// Requests that finished with an error.
    pub failed_requests: usize,
    /// Running average batch size over completed requests.
    pub average_batch_size: f32,
    /// Running average per-batch processing time (ms).
    pub average_processing_time_ms: f32,
    /// Running average queue wait time (ms).
    pub average_queue_time_ms: f32,
    /// Throughput of the most recent batch (requests / second).
    pub throughput_requests_per_second: f32,
    /// Memory usage estimate in MB. NOTE(review): never updated in this
    /// file — confirm it is populated elsewhere.
    pub memory_usage_mb: f32,
}
207
/// Batch processor for efficient inference.
///
/// Accumulates `BatchRequest`s in a queue, then batches them per the
/// configured `BatchingStrategy` and returns one `BatchResponse` each.
#[wasm_bindgen]
pub struct BatchProcessor {
    /// Active configuration (strategy, size limits, timeouts).
    config: BatchConfig,
    /// Requests waiting to be batched; kept sorted high-to-low priority
    /// when prioritization is enabled.
    pending_requests: Vec<BatchRequest>,
    /// Reserved for tracking an in-flight batch; currently unused.
    #[allow(dead_code)]
    active_batch: Option<Vec<BatchRequest>>,
    /// Running counters/averages exposed via `get_stats`.
    stats: BatchStats,
    /// `Date::now()` timestamp of the last batch; drives the timeout
    /// checks of the Dynamic and Adaptive strategies.
    last_batch_time: f64,
    /// Current batch-size target for the Adaptive strategy.
    adaptive_batch_size: usize,
    /// Monotonic counter used to generate unique request ids.
    request_counter: usize,
}
220
221#[wasm_bindgen]
222impl BatchProcessor {
223    /// Create a new batch processor
224    #[wasm_bindgen(constructor)]
225    pub fn new(config: BatchConfig) -> Self {
226        let adaptive_batch_size = config.max_batch_size.min(4); // Start with smaller batches
227
228        Self {
229            config,
230            pending_requests: Vec::new(),
231            active_batch: None,
232            stats: BatchStats {
233                total_requests: 0,
234                completed_requests: 0,
235                failed_requests: 0,
236                average_batch_size: 0.0,
237                average_processing_time_ms: 0.0,
238                average_queue_time_ms: 0.0,
239                throughput_requests_per_second: 0.0,
240                memory_usage_mb: 0.0,
241            },
242            last_batch_time: Date::now(),
243            adaptive_batch_size,
244            request_counter: 0,
245        }
246    }
247
248    /// Add a request to the batch queue
249    pub fn add_request(
250        &mut self,
251        input: WasmTensor,
252        priority: Priority,
253        timeout_ms: Option<u32>,
254    ) -> String {
255        self.request_counter += 1;
256        let request_id = format!("req_{counter}", counter = self.request_counter);
257
258        let request = BatchRequest {
259            id: request_id.clone(),
260            input,
261            priority,
262            timestamp: Date::now(),
263            timeout_ms,
264            callback: None,
265        };
266
267        // Insert request based on priority
268        if self.config.enable_prioritization {
269            let insert_pos = self
270                .pending_requests
271                .iter()
272                .position(|r| r.priority < priority)
273                .unwrap_or(self.pending_requests.len());
274            self.pending_requests.insert(insert_pos, request);
275        } else {
276            self.pending_requests.push(request);
277        }
278
279        self.stats.total_requests += 1;
280
281        web_sys::console::log_1(
282            &format!(
283                "Added request {} to batch queue (priority: {:?})",
284                request_id, priority
285            )
286            .into(),
287        );
288
289        request_id
290    }
291
292    /// Process pending requests based on batching strategy
293    pub async fn process_batch(&mut self) -> Result<Vec<BatchResponse>, JsValue> {
294        if self.pending_requests.is_empty() {
295            return Ok(Vec::new());
296        }
297
298        let batch_size = self.determine_batch_size();
299        let batch_requests = self.extract_batch(batch_size);
300
301        if batch_requests.is_empty() {
302            return Ok(Vec::new());
303        }
304
305        let batch_start_time = Date::now();
306
307        web_sys::console::log_1(
308            &format!(
309                "Processing batch of {len} requests",
310                len = batch_requests.len()
311            )
312            .into(),
313        );
314
315        // Combine inputs into a single batch tensor
316        let batch_inputs = self.combine_inputs(&batch_requests)?;
317
318        // Process the batch (this would call the actual inference engine)
319        let batch_results = self.process_batch_inference(&batch_inputs).await?;
320
321        let processing_time = Date::now() - batch_start_time;
322
323        // Split results back to individual responses
324        let responses = self.create_responses(
325            batch_requests,
326            batch_results,
327            processing_time,
328            batch_start_time,
329        );
330
331        // Update statistics
332        self.update_stats(&responses, processing_time);
333
334        // Update adaptive batch size
335        if self.config.strategy == BatchingStrategy::Adaptive {
336            self.update_adaptive_batch_size(processing_time, responses.len());
337        }
338
339        self.last_batch_time = Date::now();
340
341        Ok(responses)
342    }
343
344    /// Check if a batch is ready to process
345    pub fn is_batch_ready(&self) -> bool {
346        if self.pending_requests.is_empty() {
347            return false;
348        }
349
350        match self.config.strategy {
351            BatchingStrategy::Immediate => true,
352            BatchingStrategy::FixedSize => {
353                self.pending_requests.len() >= self.config.max_batch_size
354            },
355            BatchingStrategy::Dynamic => {
356                let elapsed = Date::now() - self.last_batch_time;
357                elapsed >= self.config.timeout_ms as f64
358                    || self.pending_requests.len() >= self.config.max_batch_size
359            },
360            BatchingStrategy::Adaptive => {
361                let elapsed = Date::now() - self.last_batch_time;
362                elapsed >= self.config.timeout_ms as f64
363                    || self.pending_requests.len() >= self.adaptive_batch_size
364            },
365        }
366    }
367
368    /// Get current queue length
369    #[wasm_bindgen(getter)]
370    pub fn queue_length(&self) -> usize {
371        self.pending_requests.len()
372    }
373
374    /// Get batch statistics
375    pub fn get_stats(&self) -> String {
376        format!(
377            "Batch Stats: {} total, {} completed, {} failed, avg batch size: {:.1}, avg processing: {:.1}ms, throughput: {:.1} req/s",
378            self.stats.total_requests,
379            self.stats.completed_requests,
380            self.stats.failed_requests,
381            self.stats.average_batch_size,
382            self.stats.average_processing_time_ms,
383            self.stats.throughput_requests_per_second
384        )
385    }
386
387    /// Clear all pending requests
388    pub fn clear_queue(&mut self) {
389        self.pending_requests.clear();
390        web_sys::console::log_1(&"Batch queue cleared".into());
391    }
392
393    /// Update configuration
394    pub fn update_config(&mut self, config: BatchConfig) {
395        self.config = config;
396        self.adaptive_batch_size = self.config.max_batch_size.min(4);
397        web_sys::console::log_1(&"Batch configuration updated".into());
398    }
399
400    // Private helper methods
401
402    fn determine_batch_size(&self) -> usize {
403        match self.config.strategy {
404            BatchingStrategy::Immediate => 1,
405            BatchingStrategy::FixedSize => {
406                self.config.max_batch_size.min(self.pending_requests.len())
407            },
408            BatchingStrategy::Dynamic => {
409                let elapsed = Date::now() - self.last_batch_time;
410                if elapsed >= self.config.timeout_ms as f64 {
411                    self.pending_requests.len().min(self.config.max_batch_size)
412                } else {
413                    self.config.max_batch_size.min(self.pending_requests.len())
414                }
415            },
416            BatchingStrategy::Adaptive => self.adaptive_batch_size.min(self.pending_requests.len()),
417        }
418    }
419
420    fn extract_batch(&mut self, batch_size: usize) -> Vec<BatchRequest> {
421        let actual_size = batch_size.min(self.pending_requests.len());
422        self.pending_requests.drain(0..actual_size).collect()
423    }
424
425    fn combine_inputs(&self, requests: &[BatchRequest]) -> Result<WasmTensor, JsValue> {
426        if requests.is_empty() {
427            return Err("No requests to process".into());
428        }
429
430        if requests.len() == 1 {
431            return Ok(requests[0].input.clone());
432        }
433
434        // Get the shape of the first tensor to validate compatibility
435        let first_shape = requests[0].input.shape();
436        let batch_size = requests.len();
437
438        // Validate that all tensors have compatible shapes for batching
439        for (i, request) in requests.iter().enumerate().skip(1) {
440            let current_shape = request.input.shape();
441            if current_shape.len() != first_shape.len() {
442                return Err(format!(
443                    "Tensor {} has incompatible rank: {} vs {}",
444                    i,
445                    current_shape.len(),
446                    first_shape.len()
447                )
448                .into());
449            }
450
451            // Check that all dimensions except the first (batch dimension) match
452            for (dim_idx, (&current_dim, &first_dim)) in
453                current_shape[1..].iter().zip(first_shape[1..].iter()).enumerate()
454            {
455                if current_dim != first_dim {
456                    return Err(format!(
457                        "Tensor {} has incompatible shape at dimension {}: {} vs {}",
458                        i,
459                        dim_idx + 1,
460                        current_dim,
461                        first_dim
462                    )
463                    .into());
464                }
465            }
466        }
467
468        // Create new shape with batch dimension
469        let mut batched_shape = first_shape.clone();
470        batched_shape[0] = batch_size;
471
472        // Calculate total size for the batched tensor
473        let total_elements = batched_shape.iter().product::<usize>();
474        let mut batched_data = vec![0.0f32; total_elements];
475
476        // Copy data from each tensor into the batched tensor
477        let elements_per_batch = first_shape.iter().product::<usize>();
478
479        for (batch_idx, request) in requests.iter().enumerate() {
480            let tensor_data = request.input.data();
481            let start_idx = batch_idx * elements_per_batch;
482            let end_idx = start_idx + elements_per_batch.min(tensor_data.len());
483
484            if end_idx <= batched_data.len() {
485                batched_data[start_idx..end_idx]
486                    .copy_from_slice(&tensor_data[..elements_per_batch.min(tensor_data.len())]);
487            }
488        }
489
490        // Create the batched tensor
491        WasmTensor::new(batched_data, batched_shape)
492    }
493
494    async fn process_batch_inference(
495        &self,
496        batch_input: &WasmTensor,
497    ) -> Result<Vec<WasmTensor>, JsValue> {
498        // Simulate inference processing time
499        let processing_delay = 10.0 + (self.pending_requests.len() as f64 * 2.0);
500
501        // Simulate inference delay (in a real implementation, this would be actual model inference)
502        let delay_promise = Promise::new(&mut |resolve, _| {
503            let _timeout_id = web_sys::window()
504                .expect("window should be available in browser context")
505                .set_timeout_with_callback_and_timeout_and_arguments_0(
506                    &resolve,
507                    processing_delay as i32,
508                )
509                .expect("set_timeout should succeed with valid callback");
510            // Note: In a real app, you'd want to track timeout_id for cleanup
511        });
512
513        JsFuture::from(delay_promise).await?;
514
515        // Perform actual inference on the batched input
516        let batch_shape = batch_input.shape();
517        let batch_size = batch_shape[0];
518
519        // For demonstration, perform a simple transformation
520        // In a real implementation, this would use the model's forward pass
521        let batch_output = match batch_shape.len() {
522            2 => {
523                // For 2D tensors (batch_size, features), return logits
524                let output_features = 10; // Assuming classification with 10 classes
525                batch_input.matmul(&WasmTensor::randn(vec![batch_shape[1], output_features])?)?
526            },
527            3 => {
528                // For 3D tensors (batch_size, seq_len, features), return sequence output
529                let _output_features = batch_shape[2]; // Same feature size
530                batch_input.relu() // Simple activation for demonstration
531            },
532            _ => {
533                // For other shapes, apply element-wise transformation
534                batch_input.relu()
535            },
536        };
537
538        // Split the batched output back into individual results
539        let output_shape = batch_output.shape();
540        let elements_per_batch = output_shape[1..].iter().product::<usize>();
541        let output_data = batch_output.data();
542
543        let mut results = Vec::new();
544        for batch_idx in 0..batch_size {
545            let start_idx = batch_idx * elements_per_batch;
546            let end_idx = start_idx + elements_per_batch;
547
548            if end_idx <= output_data.len() {
549                let batch_data = output_data[start_idx..end_idx].to_vec();
550                let mut individual_shape = output_shape[1..].to_vec();
551                individual_shape.insert(0, 1); // Add batch dimension of 1
552
553                results.push(WasmTensor::new(batch_data, individual_shape)?);
554            }
555        }
556
557        if results.len() != batch_size {
558            return Err(format!(
559                "Expected {batch_size} results but got {len}",
560                len = results.len()
561            )
562            .into());
563        }
564
565        Ok(results)
566    }
567
568    fn create_responses(
569        &self,
570        requests: Vec<BatchRequest>,
571        results: Vec<WasmTensor>,
572        processing_time: f64,
573        batch_start_time: f64,
574    ) -> Vec<BatchResponse> {
575        let batch_size = requests.len();
576        requests
577            .into_iter()
578            .zip(results)
579            .map(|(request, result)| {
580                let queue_time = batch_start_time - request.timestamp;
581
582                BatchResponse {
583                    request_id: request.id,
584                    result: Some(result),
585                    error: None,
586                    processing_time_ms: processing_time,
587                    queue_time_ms: queue_time,
588                    batch_size,
589                }
590            })
591            .collect()
592    }
593
594    fn update_stats(&mut self, responses: &[BatchResponse], processing_time: f64) {
595        let successful = responses.iter().filter(|r| r.is_success()).count();
596        let failed = responses.len() - successful;
597
598        self.stats.completed_requests += successful;
599        self.stats.failed_requests += failed;
600
601        // Update running averages
602        let total_completed = self.stats.completed_requests as f32;
603        if total_completed > 0.0 {
604            self.stats.average_batch_size = (self.stats.average_batch_size
605                * (total_completed - successful as f32)
606                + responses.len() as f32)
607                / total_completed;
608
609            self.stats.average_processing_time_ms = (self.stats.average_processing_time_ms
610                * (total_completed - successful as f32)
611                + processing_time as f32)
612                / total_completed;
613
614            if let Some(first_response) = responses.first() {
615                self.stats.average_queue_time_ms = (self.stats.average_queue_time_ms
616                    * (total_completed - 1.0)
617                    + first_response.queue_time_ms as f32)
618                    / total_completed;
619            }
620        }
621
622        // Calculate throughput (requests per second)
623        if processing_time > 0.0 {
624            self.stats.throughput_requests_per_second =
625                (responses.len() as f32) / (processing_time / 1000.0) as f32;
626        }
627    }
628
629    fn update_adaptive_batch_size(&mut self, processing_time: f64, batch_size: usize) {
630        let target_latency = self.config.target_latency_ms as f64;
631
632        if processing_time > target_latency * 1.5 {
633            // Decrease batch size if we're too slow
634            self.adaptive_batch_size = (self.adaptive_batch_size - 1).max(1);
635        } else if processing_time < target_latency * 0.7 && batch_size == self.adaptive_batch_size {
636            // Increase batch size if we're fast and the batch was full
637            self.adaptive_batch_size =
638                (self.adaptive_batch_size + 1).min(self.config.max_batch_size);
639        }
640
641        web_sys::console::log_1(
642            &format!(
643                "Adaptive batch size updated to {}",
644                self.adaptive_batch_size
645            )
646            .into(),
647        );
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Preset constructors expose the documented strategy and limits.
    #[test]
    fn test_batch_config() {
        let rt = BatchConfig::real_time();
        assert_eq!(rt.strategy, BatchingStrategy::Dynamic);
        assert!(rt.enable_prioritization);

        assert_eq!(BatchConfig::throughput().max_batch_size, 32);
    }

    /// `Priority` derives `Ord` in declaration order: Low < Normal < High < Critical.
    #[test]
    fn test_priority_ordering() {
        let mut levels = [
            Priority::High,
            Priority::Low,
            Priority::Critical,
            Priority::Normal,
        ];
        levels.sort();
        assert_eq!(
            levels,
            [
                Priority::Low,
                Priority::Normal,
                Priority::High,
                Priority::Critical
            ]
        );
    }
}
671}