Skip to main content

tensorlogic_infer/
streaming.rs

1//! Streaming execution support for large graphs and datasets.
2
3use tensorlogic_ir::EinsumGraph;
4
5use crate::batch::BatchResult;
6
/// Strategy used to partition a large input into chunks for streaming execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// Streaming is disabled; the whole input is processed in one go.
    None,
    /// Inputs are streamed in chunks holding the given number of items.
    FixedChunk(usize),
    /// Chunk size is derived from a memory budget given in megabytes.
    DynamicChunk { target_memory_mb: usize },
    /// Chunking starts at `initial_chunk` items and adapts based on performance.
    Adaptive { initial_chunk: usize },
}
19
20/// Configuration for streaming execution
21#[derive(Debug, Clone)]
22pub struct StreamingConfig {
23    pub mode: StreamingMode,
24    pub prefetch_chunks: usize,
25    pub overlap_compute_io: bool,
26    pub checkpoint_interval: Option<usize>,
27}
28
29impl StreamingConfig {
30    pub fn new(mode: StreamingMode) -> Self {
31        StreamingConfig {
32            mode,
33            prefetch_chunks: 1,
34            overlap_compute_io: true,
35            checkpoint_interval: None,
36        }
37    }
38
39    pub fn with_prefetch(mut self, num_chunks: usize) -> Self {
40        self.prefetch_chunks = num_chunks;
41        self
42    }
43
44    pub fn with_checkpointing(mut self, interval: usize) -> Self {
45        self.checkpoint_interval = Some(interval);
46        self
47    }
48
49    pub fn disable_overlap(mut self) -> Self {
50        self.overlap_compute_io = false;
51        self
52    }
53}
54
55impl Default for StreamingConfig {
56    fn default() -> Self {
57        Self::new(StreamingMode::None)
58    }
59}
60
/// Position and extent of one chunk within a streamed sequence of items.
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
    /// Sequence number of the chunk within the stream.
    pub chunk_id: usize,
    /// Index of the first item covered by the chunk (inclusive).
    pub start_idx: usize,
    /// Index one past the last item covered by the chunk (exclusive).
    pub end_idx: usize,
    /// Number of items in the chunk, i.e. `end_idx - start_idx`.
    pub size: usize,
    /// `true` when the chunk reaches (or passes) the end of the sequence.
    pub is_last: bool,
}

impl ChunkMetadata {
    /// Describes the chunk `start_idx..end_idx` of a sequence holding
    /// `total_size` items.
    ///
    /// `size` is derived from the range and `is_last` is set when
    /// `end_idx >= total_size`. An inverted range (`end_idx < start_idx`)
    /// yields a `size` of 0.
    pub fn new(chunk_id: usize, start_idx: usize, end_idx: usize, total_size: usize) -> Self {
        // A plain subtraction would panic in debug builds and wrap around to a
        // huge size in release builds when the range is inverted.
        let size = end_idx.saturating_sub(start_idx);

        Self {
            chunk_id,
            start_idx,
            end_idx,
            size,
            is_last: end_idx >= total_size,
        }
    }
}
84
85/// Streaming execution result with chunk information
86#[derive(Debug, Clone)]
87pub struct StreamResult<T> {
88    pub outputs: Vec<T>,
89    pub metadata: ChunkMetadata,
90    pub processing_time_ms: f64,
91}
92
93impl<T> StreamResult<T> {
94    pub fn new(outputs: Vec<T>, metadata: ChunkMetadata, processing_time_ms: f64) -> Self {
95        StreamResult {
96            outputs,
97            metadata,
98            processing_time_ms,
99        }
100    }
101
102    pub fn throughput_items_per_sec(&self) -> f64 {
103        if self.processing_time_ms > 0.0 {
104            (self.metadata.size as f64) / (self.processing_time_ms / 1000.0)
105        } else {
106            0.0
107        }
108    }
109}
110
111/// Trait for executors that support streaming execution
112pub trait TlStreamingExecutor {
113    type Tensor;
114    type Error;
115
116    /// Execute graph on a stream of input chunks
117    fn execute_stream(
118        &mut self,
119        graph: &EinsumGraph,
120        input_stream: Vec<Vec<Vec<Self::Tensor>>>,
121        config: &StreamingConfig,
122    ) -> Result<Vec<StreamResult<Self::Tensor>>, Self::Error>;
123
124    /// Execute graph on a single chunk with metadata
125    fn execute_chunk(
126        &mut self,
127        graph: &EinsumGraph,
128        chunk_inputs: Vec<Self::Tensor>,
129        metadata: &ChunkMetadata,
130    ) -> Result<StreamResult<Self::Tensor>, Self::Error>;
131
132    /// Get recommended chunk size based on available memory
133    fn recommend_chunk_size(&self, graph: &EinsumGraph, available_memory_mb: usize) -> usize {
134        let _ = (graph, available_memory_mb);
135        32 // Default recommendation
136    }
137
138    /// Estimate memory usage per chunk
139    fn estimate_chunk_memory(&self, graph: &EinsumGraph, chunk_size: usize) -> usize {
140        let _ = (graph, chunk_size);
141        chunk_size * 1024 * 1024 // Default: 1MB per item
142    }
143}
144
145/// Chunk iterator for breaking large batches into streams
146pub struct ChunkIterator {
147    total_size: usize,
148    chunk_size: usize,
149    current_chunk: usize,
150}
151
152impl ChunkIterator {
153    pub fn new(total_size: usize, chunk_size: usize) -> Self {
154        ChunkIterator {
155            total_size,
156            chunk_size,
157            current_chunk: 0,
158        }
159    }
160
161    pub fn from_config(total_size: usize, config: &StreamingConfig) -> Self {
162        let chunk_size = match config.mode {
163            StreamingMode::None => total_size,
164            StreamingMode::FixedChunk(size) => size,
165            StreamingMode::DynamicChunk { target_memory_mb } => {
166                // Estimate: ~1MB per item, adjust based on target memory
167                (target_memory_mb).max(1)
168            }
169            StreamingMode::Adaptive { initial_chunk } => initial_chunk,
170        };
171
172        ChunkIterator::new(total_size, chunk_size)
173    }
174
175    pub fn num_chunks(&self) -> usize {
176        self.total_size.div_ceil(self.chunk_size)
177    }
178
179    pub fn current_chunk(&self) -> usize {
180        self.current_chunk
181    }
182}
183
184impl Iterator for ChunkIterator {
185    type Item = ChunkMetadata;
186
187    fn next(&mut self) -> Option<Self::Item> {
188        let start_idx = self.current_chunk * self.chunk_size;
189        if start_idx >= self.total_size {
190            return None;
191        }
192
193        let end_idx = (start_idx + self.chunk_size).min(self.total_size);
194        let metadata = ChunkMetadata::new(self.current_chunk, start_idx, end_idx, self.total_size);
195
196        self.current_chunk += 1;
197        Some(metadata)
198    }
199}
200
201/// Stream processor for handling streaming execution
202pub struct StreamProcessor {
203    config: StreamingConfig,
204}
205
206impl StreamProcessor {
207    pub fn new(config: StreamingConfig) -> Self {
208        StreamProcessor { config }
209    }
210
211    /// Split batch result into chunks based on configuration
212    pub fn split_batch<T: Clone>(&self, batch: &BatchResult<T>) -> Vec<(ChunkMetadata, Vec<T>)> {
213        let total_size = batch.len();
214        let iter = ChunkIterator::from_config(total_size, &self.config);
215
216        iter.map(|metadata| {
217            let chunk_data: Vec<T> = batch.outputs[metadata.start_idx..metadata.end_idx].to_vec();
218            (metadata, chunk_data)
219        })
220        .collect()
221    }
222
223    /// Merge stream results back into a single batch
224    pub fn merge_results<T>(results: Vec<StreamResult<T>>) -> BatchResult<T> {
225        let total_size: usize = results.iter().map(|r| r.outputs.len()).sum();
226        let mut outputs = Vec::with_capacity(total_size);
227
228        for result in results {
229            outputs.extend(result.outputs);
230        }
231
232        BatchResult::new(outputs)
233    }
234
235    /// Calculate adaptive chunk size based on performance metrics
236    pub fn adaptive_chunk_size(&self, results: &[StreamResult<impl Clone>]) -> usize {
237        if results.is_empty() {
238            return 32; // Default
239        }
240
241        // Calculate average throughput
242        let avg_throughput: f64 = results
243            .iter()
244            .map(|r| r.throughput_items_per_sec())
245            .sum::<f64>()
246            / results.len() as f64;
247
248        // Adjust chunk size based on throughput
249        // Goal: maintain ~100ms per chunk for good responsiveness
250        let target_time_ms = 100.0;
251        let items_per_chunk = (avg_throughput * target_time_ms / 1000.0) as usize;
252
253        items_per_chunk.clamp(1, 1000) // Clamp between 1 and 1000
254    }
255
256    pub fn config(&self) -> &StreamingConfig {
257        &self.config
258    }
259}
260
261impl Default for StreamProcessor {
262    fn default() -> Self {
263        Self::new(StreamingConfig::default())
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods override the values set by `StreamingConfig::new`.
    #[test]
    fn test_streaming_config() {
        let mode = StreamingMode::FixedChunk(64);
        let config = StreamingConfig::new(mode)
            .with_prefetch(2)
            .with_checkpointing(100);

        assert_eq!(config.mode, mode);
        assert_eq!(config.prefetch_chunks, 2);
        assert_eq!(config.checkpoint_interval, Some(100));
    }

    /// `size` and `is_last` are derived from the index range.
    #[test]
    fn test_chunk_metadata() {
        let first = ChunkMetadata::new(0, 0, 32, 100);
        assert_eq!(first.chunk_id, 0);
        assert_eq!(first.size, 32);
        assert!(!first.is_last);

        let tail = ChunkMetadata::new(3, 96, 100, 100);
        assert!(tail.is_last);
    }

    /// A result with a positive processing time reports a positive throughput.
    #[test]
    fn test_stream_result() {
        let result = StreamResult::new(vec![1_i32, 2, 3], ChunkMetadata::new(0, 0, 32, 100), 100.0);

        assert_eq!(result.outputs.len(), 3);
        assert!(result.throughput_items_per_sec() > 0.0);
    }

    /// 100 items in chunks of 32 give three full chunks and a remainder of 4.
    #[test]
    fn test_chunk_iterator() {
        let iter = ChunkIterator::new(100, 32);
        assert_eq!(iter.num_chunks(), 4); // 32, 32, 32, 4

        let chunks: Vec<ChunkMetadata> = iter.collect();
        assert_eq!(chunks.len(), 4);

        let (head, tail) = (&chunks[0], &chunks[3]);
        assert_eq!(head.size, 32);
        assert_eq!(tail.size, 4);
        assert!(tail.is_last);
    }

    /// A fixed-chunk configuration determines the iterator's chunk size.
    #[test]
    fn test_chunk_iterator_from_config() {
        let iter = ChunkIterator::from_config(
            100,
            &StreamingConfig::new(StreamingMode::FixedChunk(25)),
        );

        assert_eq!(iter.chunk_size, 25);
        assert_eq!(iter.num_chunks(), 4);
    }

    /// Splitting ten items into chunks of three keeps order and content.
    #[test]
    fn test_stream_processor_split() {
        let processor = StreamProcessor::new(StreamingConfig::new(StreamingMode::FixedChunk(3)));
        let batch = BatchResult::new((1..=10).collect::<Vec<i32>>());

        let chunks = processor.split_batch(&batch);
        assert_eq!(chunks.len(), 4); // 3, 3, 3, 1

        let expected = [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9], vec![10]];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(&chunks[index].1, want);
        }
    }

    /// Merging chunk results concatenates their outputs in order.
    #[test]
    fn test_stream_processor_merge() {
        let results = vec![
            StreamResult::new(vec![1, 2, 3], ChunkMetadata::new(0, 0, 3, 10), 10.0),
            StreamResult::new(vec![4, 5, 6], ChunkMetadata::new(1, 3, 6, 10), 10.0),
            StreamResult::new(vec![7, 8, 9, 10], ChunkMetadata::new(2, 6, 10, 10), 10.0),
        ];

        let batch = StreamProcessor::merge_results(results);
        assert_eq!(batch.len(), 10);
        assert_eq!(batch.outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    /// The adaptive chunk size stays within its clamp range.
    #[test]
    fn test_adaptive_chunk_size() {
        let processor = StreamProcessor::default();
        let metadata = ChunkMetadata::new(0, 0, 100, 1000);

        // 50 ms -> 2000 items/sec, 100 ms -> 1000 items/sec, 75 ms -> 1333 items/sec
        let results: Vec<_> = [50.0, 100.0, 75.0]
            .iter()
            .map(|&elapsed_ms| StreamResult::new(vec![(); 100], metadata.clone(), elapsed_ms))
            .collect();

        let chunk_size = processor.adaptive_chunk_size(&results);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 1000); // Within clamp range
    }

    /// Modes compare by value and expose their payloads.
    #[test]
    fn test_streaming_modes() {
        assert_eq!(StreamingMode::None, StreamingConfig::default().mode);

        let fixed = StreamingMode::FixedChunk(64);
        assert_eq!(fixed, StreamingMode::FixedChunk(64));

        let dynamic = StreamingMode::DynamicChunk {
            target_memory_mb: 512,
        };
        if let StreamingMode::DynamicChunk { target_memory_mb } = dynamic {
            assert_eq!(target_memory_mb, 512);
        } else {
            panic!("Wrong mode");
        }
    }
}