1use crate::core::tensor::WasmTensor;
4use js_sys::{Date, Promise};
5use serde::{Deserialize, Serialize};
6use std::string::String;
7use std::vec::Vec;
8use wasm_bindgen::prelude::*;
9use wasm_bindgen_futures::JsFuture;
10
/// Strategy used to decide when queued requests are flushed into a batch.
///
/// Consumed by `BatchProcessor::is_batch_ready` / `determine_batch_size`.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BatchingStrategy {
    /// Dispatch every request on its own (effective batch size 1).
    Immediate,
    /// Wait until `max_batch_size` requests are queued before dispatching.
    FixedSize,
    /// Dispatch when either `timeout_ms` elapses or `max_batch_size` is reached.
    Dynamic,
    /// Like `Dynamic`, but the size threshold is the runtime-tuned
    /// `adaptive_batch_size` instead of the configured maximum.
    Adaptive,
}
24
/// Tuning knobs for a `BatchProcessor`.
///
/// Construct via `BatchConfig::new` or one of the presets
/// (`real_time`, `throughput`, `mobile`).
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct BatchConfig {
    // When/how pending requests are grouped into batches.
    strategy: BatchingStrategy,
    // Hard cap on requests per batch.
    max_batch_size: usize,
    // Max wait before a partial batch is flushed (Dynamic/Adaptive).
    timeout_ms: u32,
    // Latency target steering the Adaptive strategy's batch-size tuning.
    target_latency_ms: u32,
    // Memory budget in MB — NOTE(review): stored but not enforced anywhere
    // in this file; confirm intended use.
    memory_limit_mb: f32,
    // Insert requests in priority order rather than FIFO.
    enable_prioritization: bool,
    // Preemption flag — NOTE(review): stored but never read in this file;
    // confirm intended use.
    enable_preemption: bool,
}
37
38#[wasm_bindgen]
39impl BatchConfig {
40 #[wasm_bindgen(constructor)]
42 pub fn new(strategy: BatchingStrategy, max_batch_size: usize) -> Self {
43 Self {
44 strategy,
45 max_batch_size,
46 timeout_ms: 100, target_latency_ms: 50, memory_limit_mb: 100.0,
49 enable_prioritization: false,
50 enable_preemption: false,
51 }
52 }
53
54 pub fn real_time() -> Self {
56 Self {
57 strategy: BatchingStrategy::Dynamic,
58 max_batch_size: 4,
59 timeout_ms: 10,
60 target_latency_ms: 20,
61 memory_limit_mb: 50.0,
62 enable_prioritization: true,
63 enable_preemption: true,
64 }
65 }
66
67 pub fn throughput() -> Self {
69 Self {
70 strategy: BatchingStrategy::FixedSize,
71 max_batch_size: 32,
72 timeout_ms: 500,
73 target_latency_ms: 200,
74 memory_limit_mb: 500.0,
75 enable_prioritization: false,
76 enable_preemption: false,
77 }
78 }
79
80 pub fn mobile() -> Self {
82 Self {
83 strategy: BatchingStrategy::Adaptive,
84 max_batch_size: 2,
85 timeout_ms: 50,
86 target_latency_ms: 100,
87 memory_limit_mb: 20.0,
88 enable_prioritization: true,
89 enable_preemption: false,
90 }
91 }
92
93 pub fn set_timeout_ms(&mut self, timeout_ms: u32) {
95 self.timeout_ms = timeout_ms;
96 }
97
98 pub fn set_target_latency_ms(&mut self, latency_ms: u32) {
100 self.target_latency_ms = latency_ms;
101 }
102
103 pub fn set_memory_limit_mb(&mut self, limit_mb: f32) {
105 self.memory_limit_mb = limit_mb;
106 }
107
108 pub fn enable_prioritization(&mut self) {
110 self.enable_prioritization = true;
111 }
112
113 pub fn enable_preemption(&mut self) {
115 self.enable_preemption = true;
116 }
117}
118
/// Scheduling priority of a request; larger discriminant = more urgent.
///
/// `Ord` is derived from declaration order, so the explicit discriminants
/// and the variant order must stay in sync (Low < Normal < High < Critical).
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
128
/// A single queued inference request.
///
/// Plain Rust struct (not `#[wasm_bindgen]`-exported): it carries fields
/// such as the optional JS callback that are kept internal.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Generated identifier (`req_<n>`), also handed back to the caller.
    pub id: String,
    /// Input tensor to run inference on.
    pub input: WasmTensor,
    /// Scheduling priority; used for ordered insertion when prioritization
    /// is enabled.
    pub priority: Priority,
    /// Enqueue time in ms since the epoch (`Date::now()`), used to derive
    /// queue latency in responses.
    pub timestamp: f64,
    /// Per-request timeout — NOTE(review): stored but never enforced in this
    /// file; confirm intended use.
    pub timeout_ms: Option<u32>,
    /// Optional JS completion callback — NOTE(review): never invoked in this
    /// file; confirm intended use.
    pub callback: Option<js_sys::Function>,
}
139
/// Outcome of one request after its batch has been processed.
///
/// Exactly one of `result` / `error` is expected to be set; see `is_success`.
#[wasm_bindgen]
pub struct BatchResponse {
    // Id of the originating request.
    request_id: String,
    // Output tensor on success.
    result: Option<WasmTensor>,
    // Error message on failure.
    error: Option<String>,
    // Wall-clock time the whole batch spent in inference (shared by every
    // response of that batch).
    processing_time_ms: f64,
    // Time between enqueue and the start of the batch.
    queue_time_ms: f64,
    // Number of requests that were processed together with this one.
    batch_size: usize,
}
150
151#[wasm_bindgen]
152impl BatchResponse {
153 #[wasm_bindgen(getter)]
155 pub fn request_id(&self) -> String {
156 self.request_id.clone()
157 }
158
159 pub fn result(&self) -> Option<WasmTensor> {
161 self.result.clone()
162 }
163
164 #[wasm_bindgen(getter)]
166 pub fn error(&self) -> Option<String> {
167 self.error.clone()
168 }
169
170 #[wasm_bindgen(getter)]
172 pub fn processing_time_ms(&self) -> f64 {
173 self.processing_time_ms
174 }
175
176 #[wasm_bindgen(getter)]
178 pub fn queue_time_ms(&self) -> f64 {
179 self.queue_time_ms
180 }
181
182 #[wasm_bindgen(getter)]
184 pub fn batch_size(&self) -> usize {
185 self.batch_size
186 }
187
188 #[wasm_bindgen(getter)]
190 pub fn is_success(&self) -> bool {
191 self.error.is_none() && self.result.is_some()
192 }
193}
194
/// Running statistics accumulated by `BatchProcessor::update_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    /// Requests ever enqueued via `add_request`.
    pub total_requests: usize,
    /// Requests that completed with a result.
    pub completed_requests: usize,
    /// Requests that completed with an error.
    pub failed_requests: usize,
    /// Incrementally updated mean batch size.
    pub average_batch_size: f32,
    /// Incrementally updated mean batch processing time (ms).
    pub average_processing_time_ms: f32,
    /// Incrementally updated mean queue wait (ms).
    pub average_queue_time_ms: f32,
    /// Throughput of the most recent batch (requests / second).
    pub throughput_requests_per_second: f32,
    /// Memory usage in MB — NOTE(review): never written in this file
    /// (stays 0.0); confirm intended source of this value.
    pub memory_usage_mb: f32,
}
207
/// Queues inference requests and dispatches them in batches according to
/// the configured `BatchingStrategy`.
#[wasm_bindgen]
pub struct BatchProcessor {
    // Active tuning parameters (replaceable via `update_config`).
    config: BatchConfig,
    // FIFO (or priority-ordered) queue of not-yet-dispatched requests.
    pending_requests: Vec<BatchRequest>,
    // Reserved for tracking an in-flight batch — currently never assigned.
    #[allow(dead_code)]
    active_batch: Option<Vec<BatchRequest>>,
    // Running counters and averages, see `BatchStats`.
    stats: BatchStats,
    // `Date::now()` timestamp of the last dispatched batch; drives the
    // timeout checks of the Dynamic/Adaptive strategies.
    last_batch_time: f64,
    // Current batch-size target for the Adaptive strategy.
    adaptive_batch_size: usize,
    // Monotonic counter used to mint `req_<n>` ids.
    request_counter: usize,
}
220
221#[wasm_bindgen]
222impl BatchProcessor {
223 #[wasm_bindgen(constructor)]
225 pub fn new(config: BatchConfig) -> Self {
226 let adaptive_batch_size = config.max_batch_size.min(4); Self {
229 config,
230 pending_requests: Vec::new(),
231 active_batch: None,
232 stats: BatchStats {
233 total_requests: 0,
234 completed_requests: 0,
235 failed_requests: 0,
236 average_batch_size: 0.0,
237 average_processing_time_ms: 0.0,
238 average_queue_time_ms: 0.0,
239 throughput_requests_per_second: 0.0,
240 memory_usage_mb: 0.0,
241 },
242 last_batch_time: Date::now(),
243 adaptive_batch_size,
244 request_counter: 0,
245 }
246 }
247
248 pub fn add_request(
250 &mut self,
251 input: WasmTensor,
252 priority: Priority,
253 timeout_ms: Option<u32>,
254 ) -> String {
255 self.request_counter += 1;
256 let request_id = format!("req_{counter}", counter = self.request_counter);
257
258 let request = BatchRequest {
259 id: request_id.clone(),
260 input,
261 priority,
262 timestamp: Date::now(),
263 timeout_ms,
264 callback: None,
265 };
266
267 if self.config.enable_prioritization {
269 let insert_pos = self
270 .pending_requests
271 .iter()
272 .position(|r| r.priority < priority)
273 .unwrap_or(self.pending_requests.len());
274 self.pending_requests.insert(insert_pos, request);
275 } else {
276 self.pending_requests.push(request);
277 }
278
279 self.stats.total_requests += 1;
280
281 web_sys::console::log_1(
282 &format!(
283 "Added request {} to batch queue (priority: {:?})",
284 request_id, priority
285 )
286 .into(),
287 );
288
289 request_id
290 }
291
292 pub async fn process_batch(&mut self) -> Result<Vec<BatchResponse>, JsValue> {
294 if self.pending_requests.is_empty() {
295 return Ok(Vec::new());
296 }
297
298 let batch_size = self.determine_batch_size();
299 let batch_requests = self.extract_batch(batch_size);
300
301 if batch_requests.is_empty() {
302 return Ok(Vec::new());
303 }
304
305 let batch_start_time = Date::now();
306
307 web_sys::console::log_1(
308 &format!(
309 "Processing batch of {len} requests",
310 len = batch_requests.len()
311 )
312 .into(),
313 );
314
315 let batch_inputs = self.combine_inputs(&batch_requests)?;
317
318 let batch_results = self.process_batch_inference(&batch_inputs).await?;
320
321 let processing_time = Date::now() - batch_start_time;
322
323 let responses = self.create_responses(
325 batch_requests,
326 batch_results,
327 processing_time,
328 batch_start_time,
329 );
330
331 self.update_stats(&responses, processing_time);
333
334 if self.config.strategy == BatchingStrategy::Adaptive {
336 self.update_adaptive_batch_size(processing_time, responses.len());
337 }
338
339 self.last_batch_time = Date::now();
340
341 Ok(responses)
342 }
343
344 pub fn is_batch_ready(&self) -> bool {
346 if self.pending_requests.is_empty() {
347 return false;
348 }
349
350 match self.config.strategy {
351 BatchingStrategy::Immediate => true,
352 BatchingStrategy::FixedSize => {
353 self.pending_requests.len() >= self.config.max_batch_size
354 },
355 BatchingStrategy::Dynamic => {
356 let elapsed = Date::now() - self.last_batch_time;
357 elapsed >= self.config.timeout_ms as f64
358 || self.pending_requests.len() >= self.config.max_batch_size
359 },
360 BatchingStrategy::Adaptive => {
361 let elapsed = Date::now() - self.last_batch_time;
362 elapsed >= self.config.timeout_ms as f64
363 || self.pending_requests.len() >= self.adaptive_batch_size
364 },
365 }
366 }
367
368 #[wasm_bindgen(getter)]
370 pub fn queue_length(&self) -> usize {
371 self.pending_requests.len()
372 }
373
374 pub fn get_stats(&self) -> String {
376 format!(
377 "Batch Stats: {} total, {} completed, {} failed, avg batch size: {:.1}, avg processing: {:.1}ms, throughput: {:.1} req/s",
378 self.stats.total_requests,
379 self.stats.completed_requests,
380 self.stats.failed_requests,
381 self.stats.average_batch_size,
382 self.stats.average_processing_time_ms,
383 self.stats.throughput_requests_per_second
384 )
385 }
386
387 pub fn clear_queue(&mut self) {
389 self.pending_requests.clear();
390 web_sys::console::log_1(&"Batch queue cleared".into());
391 }
392
393 pub fn update_config(&mut self, config: BatchConfig) {
395 self.config = config;
396 self.adaptive_batch_size = self.config.max_batch_size.min(4);
397 web_sys::console::log_1(&"Batch configuration updated".into());
398 }
399
400 fn determine_batch_size(&self) -> usize {
403 match self.config.strategy {
404 BatchingStrategy::Immediate => 1,
405 BatchingStrategy::FixedSize => {
406 self.config.max_batch_size.min(self.pending_requests.len())
407 },
408 BatchingStrategy::Dynamic => {
409 let elapsed = Date::now() - self.last_batch_time;
410 if elapsed >= self.config.timeout_ms as f64 {
411 self.pending_requests.len().min(self.config.max_batch_size)
412 } else {
413 self.config.max_batch_size.min(self.pending_requests.len())
414 }
415 },
416 BatchingStrategy::Adaptive => self.adaptive_batch_size.min(self.pending_requests.len()),
417 }
418 }
419
420 fn extract_batch(&mut self, batch_size: usize) -> Vec<BatchRequest> {
421 let actual_size = batch_size.min(self.pending_requests.len());
422 self.pending_requests.drain(0..actual_size).collect()
423 }
424
425 fn combine_inputs(&self, requests: &[BatchRequest]) -> Result<WasmTensor, JsValue> {
426 if requests.is_empty() {
427 return Err("No requests to process".into());
428 }
429
430 if requests.len() == 1 {
431 return Ok(requests[0].input.clone());
432 }
433
434 let first_shape = requests[0].input.shape();
436 let batch_size = requests.len();
437
438 for (i, request) in requests.iter().enumerate().skip(1) {
440 let current_shape = request.input.shape();
441 if current_shape.len() != first_shape.len() {
442 return Err(format!(
443 "Tensor {} has incompatible rank: {} vs {}",
444 i,
445 current_shape.len(),
446 first_shape.len()
447 )
448 .into());
449 }
450
451 for (dim_idx, (¤t_dim, &first_dim)) in
453 current_shape[1..].iter().zip(first_shape[1..].iter()).enumerate()
454 {
455 if current_dim != first_dim {
456 return Err(format!(
457 "Tensor {} has incompatible shape at dimension {}: {} vs {}",
458 i,
459 dim_idx + 1,
460 current_dim,
461 first_dim
462 )
463 .into());
464 }
465 }
466 }
467
468 let mut batched_shape = first_shape.clone();
470 batched_shape[0] = batch_size;
471
472 let total_elements = batched_shape.iter().product::<usize>();
474 let mut batched_data = vec![0.0f32; total_elements];
475
476 let elements_per_batch = first_shape.iter().product::<usize>();
478
479 for (batch_idx, request) in requests.iter().enumerate() {
480 let tensor_data = request.input.data();
481 let start_idx = batch_idx * elements_per_batch;
482 let end_idx = start_idx + elements_per_batch.min(tensor_data.len());
483
484 if end_idx <= batched_data.len() {
485 batched_data[start_idx..end_idx]
486 .copy_from_slice(&tensor_data[..elements_per_batch.min(tensor_data.len())]);
487 }
488 }
489
490 WasmTensor::new(batched_data, batched_shape)
492 }
493
494 async fn process_batch_inference(
495 &self,
496 batch_input: &WasmTensor,
497 ) -> Result<Vec<WasmTensor>, JsValue> {
498 let processing_delay = 10.0 + (self.pending_requests.len() as f64 * 2.0);
500
501 let delay_promise = Promise::new(&mut |resolve, _| {
503 let _timeout_id = web_sys::window()
504 .expect("window should be available in browser context")
505 .set_timeout_with_callback_and_timeout_and_arguments_0(
506 &resolve,
507 processing_delay as i32,
508 )
509 .expect("set_timeout should succeed with valid callback");
510 });
512
513 JsFuture::from(delay_promise).await?;
514
515 let batch_shape = batch_input.shape();
517 let batch_size = batch_shape[0];
518
519 let batch_output = match batch_shape.len() {
522 2 => {
523 let output_features = 10; batch_input.matmul(&WasmTensor::randn(vec![batch_shape[1], output_features])?)?
526 },
527 3 => {
528 let _output_features = batch_shape[2]; batch_input.relu() },
532 _ => {
533 batch_input.relu()
535 },
536 };
537
538 let output_shape = batch_output.shape();
540 let elements_per_batch = output_shape[1..].iter().product::<usize>();
541 let output_data = batch_output.data();
542
543 let mut results = Vec::new();
544 for batch_idx in 0..batch_size {
545 let start_idx = batch_idx * elements_per_batch;
546 let end_idx = start_idx + elements_per_batch;
547
548 if end_idx <= output_data.len() {
549 let batch_data = output_data[start_idx..end_idx].to_vec();
550 let mut individual_shape = output_shape[1..].to_vec();
551 individual_shape.insert(0, 1); results.push(WasmTensor::new(batch_data, individual_shape)?);
554 }
555 }
556
557 if results.len() != batch_size {
558 return Err(format!(
559 "Expected {batch_size} results but got {len}",
560 len = results.len()
561 )
562 .into());
563 }
564
565 Ok(results)
566 }
567
568 fn create_responses(
569 &self,
570 requests: Vec<BatchRequest>,
571 results: Vec<WasmTensor>,
572 processing_time: f64,
573 batch_start_time: f64,
574 ) -> Vec<BatchResponse> {
575 let batch_size = requests.len();
576 requests
577 .into_iter()
578 .zip(results)
579 .map(|(request, result)| {
580 let queue_time = batch_start_time - request.timestamp;
581
582 BatchResponse {
583 request_id: request.id,
584 result: Some(result),
585 error: None,
586 processing_time_ms: processing_time,
587 queue_time_ms: queue_time,
588 batch_size,
589 }
590 })
591 .collect()
592 }
593
594 fn update_stats(&mut self, responses: &[BatchResponse], processing_time: f64) {
595 let successful = responses.iter().filter(|r| r.is_success()).count();
596 let failed = responses.len() - successful;
597
598 self.stats.completed_requests += successful;
599 self.stats.failed_requests += failed;
600
601 let total_completed = self.stats.completed_requests as f32;
603 if total_completed > 0.0 {
604 self.stats.average_batch_size = (self.stats.average_batch_size
605 * (total_completed - successful as f32)
606 + responses.len() as f32)
607 / total_completed;
608
609 self.stats.average_processing_time_ms = (self.stats.average_processing_time_ms
610 * (total_completed - successful as f32)
611 + processing_time as f32)
612 / total_completed;
613
614 if let Some(first_response) = responses.first() {
615 self.stats.average_queue_time_ms = (self.stats.average_queue_time_ms
616 * (total_completed - 1.0)
617 + first_response.queue_time_ms as f32)
618 / total_completed;
619 }
620 }
621
622 if processing_time > 0.0 {
624 self.stats.throughput_requests_per_second =
625 (responses.len() as f32) / (processing_time / 1000.0) as f32;
626 }
627 }
628
629 fn update_adaptive_batch_size(&mut self, processing_time: f64, batch_size: usize) {
630 let target_latency = self.config.target_latency_ms as f64;
631
632 if processing_time > target_latency * 1.5 {
633 self.adaptive_batch_size = (self.adaptive_batch_size - 1).max(1);
635 } else if processing_time < target_latency * 0.7 && batch_size == self.adaptive_batch_size {
636 self.adaptive_batch_size =
638 (self.adaptive_batch_size + 1).min(self.config.max_batch_size);
639 }
640
641 web_sys::console::log_1(
642 &format!(
643 "Adaptive batch size updated to {}",
644 self.adaptive_batch_size
645 )
646 .into(),
647 );
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Presets must carry their documented strategy, flags, and sizes.
    #[test]
    fn test_batch_config() {
        let config = BatchConfig::real_time();
        assert_eq!(config.strategy, BatchingStrategy::Dynamic);
        assert!(config.enable_prioritization);

        let throughput_config = BatchConfig::throughput();
        assert_eq!(throughput_config.max_batch_size, 32);
    }

    /// `Priority` must order Low < Normal < High < Critical.
    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Low < Priority::Normal);
        assert!(Priority::Normal < Priority::High);
        assert!(Priority::High < Priority::Critical);
    }
}
671}