1use tensorlogic_ir::EinsumGraph;
4
5use crate::batch::BatchResult;
6
/// How a batch should be carved into chunks for streaming execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// Process everything in a single pass (no chunking).
    None,
    /// Use chunks holding a fixed number of items.
    FixedChunk(usize),
    /// Derive the chunk size from a memory budget in megabytes
    /// (interpretation is up to `ChunkIterator::from_config`).
    DynamicChunk { target_memory_mb: usize },
    /// Begin with the given chunk size and let callers adapt it over time.
    Adaptive { initial_chunk: usize },
}

/// Tunables for streaming execution.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Chunking strategy.
    pub mode: StreamingMode,
    /// Number of chunks to stage ahead of the compute stage.
    pub prefetch_chunks: usize,
    /// Whether compute and I/O are allowed to run concurrently.
    pub overlap_compute_io: bool,
    /// Checkpoint every `n` chunks when `Some(n)`; never when `None`.
    pub checkpoint_interval: Option<usize>,
}

impl StreamingConfig {
    /// Create a config for `mode` with defaults: one prefetched chunk,
    /// compute/I-O overlap enabled, checkpointing disabled.
    pub fn new(mode: StreamingMode) -> Self {
        Self {
            mode,
            prefetch_chunks: 1,
            overlap_compute_io: true,
            checkpoint_interval: None,
        }
    }

    /// Builder: set how many chunks to prefetch.
    pub fn with_prefetch(mut self, num_chunks: usize) -> Self {
        self.prefetch_chunks = num_chunks;
        self
    }

    /// Builder: enable checkpointing every `interval` chunks.
    pub fn with_checkpointing(mut self, interval: usize) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builder: force compute and I/O to run sequentially.
    pub fn disable_overlap(mut self) -> Self {
        self.overlap_compute_io = false;
        self
    }
}

impl Default for StreamingConfig {
    /// The default configuration does no chunking at all.
    fn default() -> Self {
        Self::new(StreamingMode::None)
    }
}
60
/// Describes one chunk's position within the full batch.
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
    /// Sequential index of this chunk (0-based).
    pub chunk_id: usize,
    /// First item index covered by the chunk (inclusive).
    pub start_idx: usize,
    /// One past the last covered item index (exclusive).
    pub end_idx: usize,
    /// Number of items in the chunk (`end_idx - start_idx`).
    pub size: usize,
    /// True when this chunk reaches (or passes) the end of the batch.
    pub is_last: bool,
}

impl ChunkMetadata {
    /// Derive metadata for the half-open range `[start_idx, end_idx)` of a
    /// batch containing `total_size` items.
    ///
    /// NOTE(review): assumes `start_idx <= end_idx`; the size subtraction
    /// would otherwise underflow.
    pub fn new(chunk_id: usize, start_idx: usize, end_idx: usize, total_size: usize) -> Self {
        Self {
            chunk_id,
            start_idx,
            end_idx,
            size: end_idx - start_idx,
            is_last: end_idx >= total_size,
        }
    }
}
84
85#[derive(Debug, Clone)]
87pub struct StreamResult<T> {
88 pub outputs: Vec<T>,
89 pub metadata: ChunkMetadata,
90 pub processing_time_ms: f64,
91}
92
93impl<T> StreamResult<T> {
94 pub fn new(outputs: Vec<T>, metadata: ChunkMetadata, processing_time_ms: f64) -> Self {
95 StreamResult {
96 outputs,
97 metadata,
98 processing_time_ms,
99 }
100 }
101
102 pub fn throughput_items_per_sec(&self) -> f64 {
103 if self.processing_time_ms > 0.0 {
104 (self.metadata.size as f64) / (self.processing_time_ms / 1000.0)
105 } else {
106 0.0
107 }
108 }
109}
110
/// Backend-agnostic executor that runs an einsum graph over a stream of
/// input chunks.
pub trait TlStreamingExecutor {
    /// Backend tensor type.
    type Tensor;
    /// Backend error type.
    type Error;

    /// Execute `graph` over a pre-chunked input stream, returning one
    /// result per chunk.
    ///
    /// NOTE(review): the exact nesting of `input_stream` (presumably
    /// chunks -> inputs -> tensors) is not evident from this file —
    /// confirm with implementors.
    fn execute_stream(
        &mut self,
        graph: &EinsumGraph,
        input_stream: Vec<Vec<Vec<Self::Tensor>>>,
        config: &StreamingConfig,
    ) -> Result<Vec<StreamResult<Self::Tensor>>, Self::Error>;

    /// Execute `graph` on a single chunk of inputs.
    fn execute_chunk(
        &mut self,
        graph: &EinsumGraph,
        chunk_inputs: Vec<Self::Tensor>,
        metadata: &ChunkMetadata,
    ) -> Result<StreamResult<Self::Tensor>, Self::Error>;

    /// Suggest a chunk size for `graph` under the given memory budget.
    /// The default implementation ignores both arguments and returns 32.
    fn recommend_chunk_size(&self, graph: &EinsumGraph, available_memory_mb: usize) -> usize {
        let _ = (graph, available_memory_mb);
        32
    }

    /// Estimate the memory footprint of one chunk, in bytes.
    /// The default implementation assumes a flat 1 MiB per item.
    fn estimate_chunk_memory(&self, graph: &EinsumGraph, chunk_size: usize) -> usize {
        let _ = (graph, chunk_size);
        chunk_size * 1024 * 1024
    }
}
144
/// Iterator yielding `ChunkMetadata` that covers `0..total_size` in
/// consecutive chunks of at most `chunk_size` items.
pub struct ChunkIterator {
    // Total number of items to cover.
    total_size: usize,
    // Maximum number of items per chunk.
    chunk_size: usize,
    // Index of the next chunk to yield.
    current_chunk: usize,
}
151
152impl ChunkIterator {
153 pub fn new(total_size: usize, chunk_size: usize) -> Self {
154 ChunkIterator {
155 total_size,
156 chunk_size,
157 current_chunk: 0,
158 }
159 }
160
161 pub fn from_config(total_size: usize, config: &StreamingConfig) -> Self {
162 let chunk_size = match config.mode {
163 StreamingMode::None => total_size,
164 StreamingMode::FixedChunk(size) => size,
165 StreamingMode::DynamicChunk { target_memory_mb } => {
166 (target_memory_mb).max(1)
168 }
169 StreamingMode::Adaptive { initial_chunk } => initial_chunk,
170 };
171
172 ChunkIterator::new(total_size, chunk_size)
173 }
174
175 pub fn num_chunks(&self) -> usize {
176 self.total_size.div_ceil(self.chunk_size)
177 }
178
179 pub fn current_chunk(&self) -> usize {
180 self.current_chunk
181 }
182}
183
184impl Iterator for ChunkIterator {
185 type Item = ChunkMetadata;
186
187 fn next(&mut self) -> Option<Self::Item> {
188 let start_idx = self.current_chunk * self.chunk_size;
189 if start_idx >= self.total_size {
190 return None;
191 }
192
193 let end_idx = (start_idx + self.chunk_size).min(self.total_size);
194 let metadata = ChunkMetadata::new(self.current_chunk, start_idx, end_idx, self.total_size);
195
196 self.current_chunk += 1;
197 Some(metadata)
198 }
199}
200
/// Splits batches into chunks and merges per-chunk results back together,
/// according to a `StreamingConfig`.
pub struct StreamProcessor {
    // Streaming configuration driving chunking decisions.
    config: StreamingConfig,
}
205
206impl StreamProcessor {
207 pub fn new(config: StreamingConfig) -> Self {
208 StreamProcessor { config }
209 }
210
211 pub fn split_batch<T: Clone>(&self, batch: &BatchResult<T>) -> Vec<(ChunkMetadata, Vec<T>)> {
213 let total_size = batch.len();
214 let iter = ChunkIterator::from_config(total_size, &self.config);
215
216 iter.map(|metadata| {
217 let chunk_data: Vec<T> = batch.outputs[metadata.start_idx..metadata.end_idx].to_vec();
218 (metadata, chunk_data)
219 })
220 .collect()
221 }
222
223 pub fn merge_results<T>(results: Vec<StreamResult<T>>) -> BatchResult<T> {
225 let total_size: usize = results.iter().map(|r| r.outputs.len()).sum();
226 let mut outputs = Vec::with_capacity(total_size);
227
228 for result in results {
229 outputs.extend(result.outputs);
230 }
231
232 BatchResult::new(outputs)
233 }
234
235 pub fn adaptive_chunk_size(&self, results: &[StreamResult<impl Clone>]) -> usize {
237 if results.is_empty() {
238 return 32; }
240
241 let avg_throughput: f64 = results
243 .iter()
244 .map(|r| r.throughput_items_per_sec())
245 .sum::<f64>()
246 / results.len() as f64;
247
248 let target_time_ms = 100.0;
251 let items_per_chunk = (avg_throughput * target_time_ms / 1000.0) as usize;
252
253 items_per_chunk.clamp(1, 1000) }
255
256 pub fn config(&self) -> &StreamingConfig {
257 &self.config
258 }
259}
260
261impl Default for StreamProcessor {
262 fn default() -> Self {
263 Self::new(StreamingConfig::default())
264 }
265}
266
// Reformatted: several statements had been collapsed onto single lines
// (rustfmt violations); assertions are unchanged.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_config() {
        let config = StreamingConfig::new(StreamingMode::FixedChunk(64))
            .with_prefetch(2)
            .with_checkpointing(100);

        assert_eq!(config.mode, StreamingMode::FixedChunk(64));
        assert_eq!(config.prefetch_chunks, 2);
        assert_eq!(config.checkpoint_interval, Some(100));
    }

    #[test]
    fn test_chunk_metadata() {
        let metadata = ChunkMetadata::new(0, 0, 32, 100);
        assert_eq!(metadata.chunk_id, 0);
        assert_eq!(metadata.size, 32);
        assert!(!metadata.is_last);

        let last_metadata = ChunkMetadata::new(3, 96, 100, 100);
        assert!(last_metadata.is_last);
    }

    #[test]
    fn test_stream_result() {
        let metadata = ChunkMetadata::new(0, 0, 32, 100);
        let result: StreamResult<i32> = StreamResult::new(vec![1, 2, 3], metadata, 100.0);

        assert_eq!(result.outputs.len(), 3);
        let throughput = result.throughput_items_per_sec();
        assert!(throughput > 0.0);
    }

    #[test]
    fn test_chunk_iterator() {
        let iter = ChunkIterator::new(100, 32);
        assert_eq!(iter.num_chunks(), 4);

        let chunks: Vec<_> = iter.collect();
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].size, 32);
        // Last chunk holds the 100 % 32 = 4 leftover items.
        assert_eq!(chunks[3].size, 4);
        assert!(chunks[3].is_last);
    }

    #[test]
    fn test_chunk_iterator_from_config() {
        let config = StreamingConfig::new(StreamingMode::FixedChunk(25));
        let iter = ChunkIterator::from_config(100, &config);

        assert_eq!(iter.chunk_size, 25);
        assert_eq!(iter.num_chunks(), 4);
    }

    #[test]
    fn test_stream_processor_split() {
        let batch = BatchResult::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let config = StreamingConfig::new(StreamingMode::FixedChunk(3));
        let processor = StreamProcessor::new(config);

        let chunks = processor.split_batch(&batch);
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].1, vec![1, 2, 3]);
        assert_eq!(chunks[1].1, vec![4, 5, 6]);
        assert_eq!(chunks[2].1, vec![7, 8, 9]);
        assert_eq!(chunks[3].1, vec![10]);
    }

    #[test]
    fn test_stream_processor_merge() {
        let metadata1 = ChunkMetadata::new(0, 0, 3, 10);
        let metadata2 = ChunkMetadata::new(1, 3, 6, 10);
        let metadata3 = ChunkMetadata::new(2, 6, 10, 10);

        let results = vec![
            StreamResult::new(vec![1, 2, 3], metadata1, 10.0),
            StreamResult::new(vec![4, 5, 6], metadata2, 10.0),
            StreamResult::new(vec![7, 8, 9, 10], metadata3, 10.0),
        ];

        let batch = StreamProcessor::merge_results(results);
        assert_eq!(batch.len(), 10);
        assert_eq!(batch.outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_adaptive_chunk_size() {
        let processor = StreamProcessor::default();

        let metadata = ChunkMetadata::new(0, 0, 100, 1000);
        let results = vec![
            StreamResult::new(vec![(); 100], metadata.clone(), 50.0),
            StreamResult::new(vec![(); 100], metadata.clone(), 100.0),
            StreamResult::new(vec![(); 100], metadata, 75.0),
        ];

        let chunk_size = processor.adaptive_chunk_size(&results);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 1000);
    }

    #[test]
    fn test_streaming_modes() {
        assert_eq!(StreamingMode::None, StreamingConfig::default().mode);

        let fixed = StreamingMode::FixedChunk(64);
        assert_eq!(fixed, StreamingMode::FixedChunk(64));

        let dynamic = StreamingMode::DynamicChunk {
            target_memory_mb: 512,
        };
        match dynamic {
            StreamingMode::DynamicChunk { target_memory_mb } => {
                assert_eq!(target_memory_mb, 512);
            }
            _ => panic!("Wrong mode"),
        }
    }
}
390
/// What to do when the chunk buffer is saturated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Block the producer until space frees up.
    Block,
    /// Evict the oldest buffered chunk to make room.
    DropOldest,
    /// Discard the incoming chunk, keeping the buffer as-is.
    DropNewest,
    /// Report an error instead of buffering or dropping.
    ErrorOnFull,
}

/// Buffering limits and the strategy applied when they are exceeded.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Hard cap on the number of buffered chunks.
    pub max_buffered_chunks: usize,
    /// Fraction of the cap above which backpressure engages.
    pub high_watermark: f64,
    /// Fraction of the cap below which backpressure releases.
    pub low_watermark: f64,
    /// Behavior when the buffer is full.
    pub strategy: BackpressureStrategy,
}

impl BackpressureConfig {
    /// Defaults: engage above 80% of `max_buffered`, release below 20%,
    /// and block the producer when full.
    pub fn new(max_buffered: usize) -> Self {
        Self {
            max_buffered_chunks: max_buffered,
            high_watermark: 0.8,
            low_watermark: 0.2,
            strategy: BackpressureStrategy::Block,
        }
    }

    /// Builder: override both watermark fractions.
    pub fn with_watermarks(mut self, high: f64, low: f64) -> Self {
        self.high_watermark = high;
        self.low_watermark = low;
        self
    }

    /// Builder: override the saturation strategy.
    pub fn with_strategy(mut self, strategy: BackpressureStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// True when `current_buffered` is strictly above the high-watermark
    /// threshold (the fractional threshold is truncated to a whole count).
    pub fn is_above_high_watermark(&self, current_buffered: usize) -> bool {
        current_buffered > (self.max_buffered_chunks as f64 * self.high_watermark) as usize
    }

    /// True when `current_buffered` is strictly below the low-watermark
    /// threshold (the fractional threshold is truncated to a whole count).
    pub fn is_below_low_watermark(&self, current_buffered: usize) -> bool {
        current_buffered < (self.max_buffered_chunks as f64 * self.low_watermark) as usize
    }

    /// Backpressure applies once the high watermark is crossed.
    pub fn should_apply_backpressure(&self, current_buffered: usize) -> bool {
        self.is_above_high_watermark(current_buffered)
    }
}
466
/// Event-time watermark settings for handling out-of-order stream events.
#[derive(Debug, Clone)]
pub struct WatermarkConfig {
    /// Maximum tolerated out-of-orderness, in milliseconds.
    pub max_out_of_order_ms: u64,
    /// Optional idle timeout in milliseconds (`None` = never idle out).
    pub idle_timeout_ms: Option<u64>,
    /// Whether events behind the watermark should be discarded.
    pub drop_late_events: bool,
}

impl WatermarkConfig {
    /// Build a config tolerating up to `max_out_of_order_ms` of lateness;
    /// no idle timeout, late events kept.
    pub fn new(max_out_of_order_ms: u64) -> Self {
        Self {
            max_out_of_order_ms,
            idle_timeout_ms: None,
            drop_late_events: false,
        }
    }

    /// Builder: set the idle timeout in milliseconds.
    pub fn with_idle_timeout(mut self, timeout_ms: u64) -> Self {
        self.idle_timeout_ms = Some(timeout_ms);
        self
    }

    /// Builder: choose whether late events are dropped.
    pub fn with_drop_late(mut self, drop: bool) -> Self {
        self.drop_late_events = drop;
        self
    }

    /// Watermark = newest observed event time minus the allowed lateness,
    /// saturating at zero.
    pub fn current_watermark(&self, max_event_time_ms: u64) -> u64 {
        max_event_time_ms.saturating_sub(self.max_out_of_order_ms)
    }

    /// An event is late when its timestamp is strictly behind the watermark.
    pub fn is_late(&self, event_time_ms: u64, watermark_ms: u64) -> bool {
        event_time_ms < watermark_ms
    }
}
513
/// Extended streaming configuration: the base settings plus optional
/// backpressure and event-time watermark handling.
#[derive(Debug, Clone)]
pub struct StreamingConfigV2 {
    /// Base chunking / prefetch / checkpoint settings.
    pub base: StreamingConfig,
    /// Optional backpressure settings (disabled when `None`).
    pub backpressure: Option<BackpressureConfig>,
    /// Optional watermark settings (disabled when `None`).
    pub watermark: Option<WatermarkConfig>,
}
524
525impl StreamingConfigV2 {
526 pub fn new(base: StreamingConfig) -> Self {
528 StreamingConfigV2 {
529 base,
530 backpressure: None,
531 watermark: None,
532 }
533 }
534
535 pub fn with_backpressure(mut self, config: BackpressureConfig) -> Self {
537 self.backpressure = Some(config);
538 self
539 }
540
541 pub fn with_watermark(mut self, config: WatermarkConfig) -> Self {
543 self.watermark = Some(config);
544 self
545 }
546
547 pub fn should_apply_backpressure(&self, current_buffered: usize) -> bool {
549 self.backpressure
550 .as_ref()
551 .is_some_and(|bp| bp.should_apply_backpressure(current_buffered))
552 }
553
554 pub fn is_late_event(&self, event_time_ms: u64, watermark_ms: u64) -> bool {
556 self.watermark
557 .as_ref()
558 .is_some_and(|wm| wm.is_late(event_time_ms, watermark_ms))
559 }
560}
561
562impl Default for StreamingConfigV2 {
563 fn default() -> Self {
564 Self::new(StreamingConfig::default())
565 }
566}
567
/// Counters accumulated while running a stream.
#[derive(Debug, Clone, Default)]
pub struct StreamingStats {
    /// Chunks fully processed.
    pub chunks_processed: usize,
    /// Chunks dropped (e.g. by a backpressure strategy).
    pub chunks_dropped: usize,
    /// Number of times backpressure engaged.
    pub backpressure_events: usize,
    /// Late events discarded due to watermarking.
    pub late_events_dropped: usize,
    /// Total processing time across all chunks, in milliseconds.
    pub total_processing_time_ms: u64,
    /// Total items processed across all chunks.
    pub total_elements_processed: usize,
}

impl StreamingStats {
    /// Mean processing time per chunk in milliseconds (0.0 with no chunks).
    pub fn average_latency_ms(&self) -> f64 {
        match self.chunks_processed {
            0 => 0.0,
            n => self.total_processing_time_ms as f64 / n as f64,
        }
    }

    /// Fraction of chunks dropped out of all chunks seen (0.0 when none
    /// have been seen).
    pub fn drop_rate(&self) -> f64 {
        let seen = self.chunks_processed + self.chunks_dropped;
        if seen == 0 {
            0.0
        } else {
            self.chunks_dropped as f64 / seen as f64
        }
    }

    /// Chunks processed per second of total processing time (0.0 when no
    /// time has been recorded).
    pub fn throughput_chunks_per_sec(&self) -> f64 {
        if self.total_processing_time_ms == 0 {
            return 0.0;
        }
        let seconds = self.total_processing_time_ms as f64 / 1000.0;
        self.chunks_processed as f64 / seconds
    }

    /// Accumulate `other`'s counters into `self`.
    pub fn merge(&mut self, other: &StreamingStats) {
        self.chunks_processed += other.chunks_processed;
        self.chunks_dropped += other.chunks_dropped;
        self.backpressure_events += other.backpressure_events;
        self.late_events_dropped += other.late_events_dropped;
        self.total_processing_time_ms += other.total_processing_time_ms;
        self.total_elements_processed += other.total_elements_processed;
    }
}
624
// Reformatted: several statements had been collapsed onto single lines
// (rustfmt violations); assertions are unchanged.
#[cfg(test)]
mod v2_tests {
    use super::*;

    #[test]
    fn test_backpressure_config_new() {
        let cfg = BackpressureConfig::new(100);
        assert_eq!(cfg.max_buffered_chunks, 100);
        assert!((cfg.high_watermark - 0.8).abs() < f64::EPSILON);
        assert!((cfg.low_watermark - 0.2).abs() < f64::EPSILON);
        assert_eq!(cfg.strategy, BackpressureStrategy::Block);
    }

    #[test]
    fn test_backpressure_above_high_watermark() {
        let cfg = BackpressureConfig::new(100);
        // Threshold is 100 * 0.8 = 80; "above" is strict.
        assert!(cfg.is_above_high_watermark(81));
        assert!(!cfg.is_above_high_watermark(80));
        assert!(!cfg.is_above_high_watermark(0));
    }

    #[test]
    fn test_backpressure_below_low_watermark() {
        let cfg = BackpressureConfig::new(100);
        // Threshold is 100 * 0.2 = 20; "below" is strict.
        assert!(cfg.is_below_low_watermark(19));
        assert!(!cfg.is_below_low_watermark(20));
        assert!(!cfg.is_below_low_watermark(100));
    }

    #[test]
    fn test_backpressure_between_watermarks() {
        let cfg = BackpressureConfig::new(100);
        assert!(!cfg.should_apply_backpressure(50));
        assert!(cfg.should_apply_backpressure(81));
    }

    #[test]
    fn test_backpressure_strategy_variants() {
        let block = BackpressureStrategy::Block;
        let drop_oldest = BackpressureStrategy::DropOldest;
        let drop_newest = BackpressureStrategy::DropNewest;
        let error = BackpressureStrategy::ErrorOnFull;

        assert_ne!(drop_oldest, block);
        assert_ne!(drop_newest, block);
        assert_ne!(error, block);
        assert_ne!(drop_oldest, drop_newest);

        let cfg = BackpressureConfig::new(10).with_strategy(BackpressureStrategy::DropOldest);
        assert_eq!(cfg.strategy, drop_oldest);
        let _ = error;
    }

    #[test]
    fn test_watermark_config_new() {
        let wm = WatermarkConfig::new(100);
        assert_eq!(wm.max_out_of_order_ms, 100);
        assert_eq!(wm.idle_timeout_ms, None);
        assert!(!wm.drop_late_events);
    }

    #[test]
    fn test_watermark_current_watermark_calculation() {
        let wm = WatermarkConfig::new(100);
        assert_eq!(wm.current_watermark(500), 400);

        // Saturates at zero instead of underflowing.
        let wm2 = WatermarkConfig::new(1000);
        assert_eq!(wm2.current_watermark(500), 0);
    }

    #[test]
    fn test_watermark_is_late_event() {
        let wm = WatermarkConfig::new(100);
        assert!(wm.is_late(300, 400));
        assert!(!wm.is_late(400, 400));
        assert!(!wm.is_late(500, 400));
    }

    #[test]
    fn test_watermark_with_idle_timeout() {
        let wm = WatermarkConfig::new(100).with_idle_timeout(5000);
        assert_eq!(wm.idle_timeout_ms, Some(5000));
        assert_eq!(wm.max_out_of_order_ms, 100);
    }

    #[test]
    fn test_streaming_stats_default() {
        let stats = StreamingStats::default();
        assert_eq!(stats.chunks_processed, 0);
        assert_eq!(stats.chunks_dropped, 0);
        assert!((stats.average_latency_ms() - 0.0).abs() < f64::EPSILON);
        assert!((stats.drop_rate() - 0.0).abs() < f64::EPSILON);
        assert!((stats.throughput_chunks_per_sec() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_streaming_stats_drop_rate() {
        let stats = StreamingStats {
            chunks_processed: 9,
            chunks_dropped: 1,
            ..Default::default()
        };
        assert!((stats.drop_rate() - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_streaming_stats_merge() {
        let mut a = StreamingStats {
            chunks_processed: 10,
            chunks_dropped: 2,
            backpressure_events: 1,
            late_events_dropped: 3,
            total_processing_time_ms: 500,
            total_elements_processed: 100,
        };
        let b = StreamingStats {
            chunks_processed: 5,
            chunks_dropped: 1,
            backpressure_events: 2,
            late_events_dropped: 0,
            total_processing_time_ms: 250,
            total_elements_processed: 50,
        };
        a.merge(&b);
        assert_eq!(a.chunks_processed, 15);
        assert_eq!(a.chunks_dropped, 3);
        assert_eq!(a.backpressure_events, 3);
        assert_eq!(a.late_events_dropped, 3);
        assert_eq!(a.total_processing_time_ms, 750);
        assert_eq!(a.total_elements_processed, 150);
    }

    #[test]
    fn test_streaming_config_v2_new() {
        let cfg = StreamingConfigV2::new(StreamingConfig::default());
        assert!(cfg.backpressure.is_none());
        assert!(cfg.watermark.is_none());
    }

    #[test]
    fn test_streaming_config_v2_with_backpressure() {
        // Without backpressure configured, nothing ever triggers.
        let cfg_none = StreamingConfigV2::new(StreamingConfig::default());
        assert!(!cfg_none.should_apply_backpressure(0));
        assert!(!cfg_none.should_apply_backpressure(usize::MAX));

        let bp = BackpressureConfig::new(100);
        let cfg = StreamingConfigV2::new(StreamingConfig::default()).with_backpressure(bp);
        assert!(!cfg.should_apply_backpressure(50));
        assert!(cfg.should_apply_backpressure(81));
    }

    #[test]
    fn test_streaming_config_v2_combined() {
        let bp = BackpressureConfig::new(50);
        let wm = WatermarkConfig::new(200);
        let cfg = StreamingConfigV2::new(StreamingConfig::default())
            .with_backpressure(bp)
            .with_watermark(wm);
        assert!(cfg.backpressure.is_some());
        assert!(cfg.watermark.is_some());

        assert!(cfg.should_apply_backpressure(41));
        assert!(cfg.is_late_event(100, 300));
        assert!(!cfg.is_late_event(400, 300));
    }
}