1use tensorlogic_ir::EinsumGraph;
4
5use crate::batch::BatchResult;
6
/// Strategy for breaking a batch into chunks for streamed execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// No streaming: the whole batch is processed as one chunk.
    None,
    /// Split into chunks with a fixed number of items each.
    FixedChunk(usize),
    /// Choose a chunk size from a target memory budget, in megabytes.
    DynamicChunk { target_memory_mb: usize },
    /// Start from `initial_chunk` items and adapt as throughput is observed.
    Adaptive { initial_chunk: usize },
}
19
/// Tuning knobs for streamed execution.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Chunking strategy to apply.
    pub mode: StreamingMode,
    /// How many chunks to prefetch ahead of the one being processed.
    pub prefetch_chunks: usize,
    /// Whether compute for one chunk may overlap I/O for the next.
    pub overlap_compute_io: bool,
    /// Checkpoint every N chunks; `None` disables checkpointing.
    pub checkpoint_interval: Option<usize>,
}
28
29impl StreamingConfig {
30 pub fn new(mode: StreamingMode) -> Self {
31 StreamingConfig {
32 mode,
33 prefetch_chunks: 1,
34 overlap_compute_io: true,
35 checkpoint_interval: None,
36 }
37 }
38
39 pub fn with_prefetch(mut self, num_chunks: usize) -> Self {
40 self.prefetch_chunks = num_chunks;
41 self
42 }
43
44 pub fn with_checkpointing(mut self, interval: usize) -> Self {
45 self.checkpoint_interval = Some(interval);
46 self
47 }
48
49 pub fn disable_overlap(mut self) -> Self {
50 self.overlap_compute_io = false;
51 self
52 }
53}
54
55impl Default for StreamingConfig {
56 fn default() -> Self {
57 Self::new(StreamingMode::None)
58 }
59}
60
/// Position of one chunk within the overall stream.
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
    /// Zero-based index of this chunk.
    pub chunk_id: usize,
    /// First item index covered by this chunk (inclusive).
    pub start_idx: usize,
    /// One past the last item index covered by this chunk (exclusive).
    pub end_idx: usize,
    /// Number of items in this chunk (`end_idx - start_idx`).
    pub size: usize,
    /// True when this chunk reaches the end of the stream.
    pub is_last: bool,
}
70
71impl ChunkMetadata {
72 pub fn new(chunk_id: usize, start_idx: usize, end_idx: usize, total_size: usize) -> Self {
73 let size = end_idx - start_idx;
74 let is_last = end_idx >= total_size;
75 ChunkMetadata {
76 chunk_id,
77 start_idx,
78 end_idx,
79 size,
80 is_last,
81 }
82 }
83}
84
/// Output of processing a single chunk, together with timing information.
#[derive(Debug, Clone)]
pub struct StreamResult<T> {
    /// Items produced for this chunk.
    pub outputs: Vec<T>,
    /// Which slice of the stream this result covers.
    pub metadata: ChunkMetadata,
    /// Wall-clock time spent processing the chunk, in milliseconds.
    pub processing_time_ms: f64,
}
92
93impl<T> StreamResult<T> {
94 pub fn new(outputs: Vec<T>, metadata: ChunkMetadata, processing_time_ms: f64) -> Self {
95 StreamResult {
96 outputs,
97 metadata,
98 processing_time_ms,
99 }
100 }
101
102 pub fn throughput_items_per_sec(&self) -> f64 {
103 if self.processing_time_ms > 0.0 {
104 (self.metadata.size as f64) / (self.processing_time_ms / 1000.0)
105 } else {
106 0.0
107 }
108 }
109}
110
111pub trait TlStreamingExecutor {
113 type Tensor;
114 type Error;
115
116 fn execute_stream(
118 &mut self,
119 graph: &EinsumGraph,
120 input_stream: Vec<Vec<Vec<Self::Tensor>>>,
121 config: &StreamingConfig,
122 ) -> Result<Vec<StreamResult<Self::Tensor>>, Self::Error>;
123
124 fn execute_chunk(
126 &mut self,
127 graph: &EinsumGraph,
128 chunk_inputs: Vec<Self::Tensor>,
129 metadata: &ChunkMetadata,
130 ) -> Result<StreamResult<Self::Tensor>, Self::Error>;
131
132 fn recommend_chunk_size(&self, graph: &EinsumGraph, available_memory_mb: usize) -> usize {
134 let _ = (graph, available_memory_mb);
135 32 }
137
138 fn estimate_chunk_memory(&self, graph: &EinsumGraph, chunk_size: usize) -> usize {
140 let _ = (graph, chunk_size);
141 chunk_size * 1024 * 1024 }
143}
144
/// Iterator yielding `ChunkMetadata` for consecutive chunks of a batch.
pub struct ChunkIterator {
    // Total number of items to cover.
    total_size: usize,
    // Items per chunk; the final chunk may be smaller.
    chunk_size: usize,
    // Index of the next chunk to yield.
    current_chunk: usize,
}
151
152impl ChunkIterator {
153 pub fn new(total_size: usize, chunk_size: usize) -> Self {
154 ChunkIterator {
155 total_size,
156 chunk_size,
157 current_chunk: 0,
158 }
159 }
160
161 pub fn from_config(total_size: usize, config: &StreamingConfig) -> Self {
162 let chunk_size = match config.mode {
163 StreamingMode::None => total_size,
164 StreamingMode::FixedChunk(size) => size,
165 StreamingMode::DynamicChunk { target_memory_mb } => {
166 (target_memory_mb).max(1)
168 }
169 StreamingMode::Adaptive { initial_chunk } => initial_chunk,
170 };
171
172 ChunkIterator::new(total_size, chunk_size)
173 }
174
175 pub fn num_chunks(&self) -> usize {
176 self.total_size.div_ceil(self.chunk_size)
177 }
178
179 pub fn current_chunk(&self) -> usize {
180 self.current_chunk
181 }
182}
183
184impl Iterator for ChunkIterator {
185 type Item = ChunkMetadata;
186
187 fn next(&mut self) -> Option<Self::Item> {
188 let start_idx = self.current_chunk * self.chunk_size;
189 if start_idx >= self.total_size {
190 return None;
191 }
192
193 let end_idx = (start_idx + self.chunk_size).min(self.total_size);
194 let metadata = ChunkMetadata::new(self.current_chunk, start_idx, end_idx, self.total_size);
195
196 self.current_chunk += 1;
197 Some(metadata)
198 }
199}
200
/// Splits batches into chunks and merges per-chunk results back together.
pub struct StreamProcessor {
    // Streaming configuration used when splitting batches.
    config: StreamingConfig,
}
205
206impl StreamProcessor {
207 pub fn new(config: StreamingConfig) -> Self {
208 StreamProcessor { config }
209 }
210
211 pub fn split_batch<T: Clone>(&self, batch: &BatchResult<T>) -> Vec<(ChunkMetadata, Vec<T>)> {
213 let total_size = batch.len();
214 let iter = ChunkIterator::from_config(total_size, &self.config);
215
216 iter.map(|metadata| {
217 let chunk_data: Vec<T> = batch.outputs[metadata.start_idx..metadata.end_idx].to_vec();
218 (metadata, chunk_data)
219 })
220 .collect()
221 }
222
223 pub fn merge_results<T>(results: Vec<StreamResult<T>>) -> BatchResult<T> {
225 let total_size: usize = results.iter().map(|r| r.outputs.len()).sum();
226 let mut outputs = Vec::with_capacity(total_size);
227
228 for result in results {
229 outputs.extend(result.outputs);
230 }
231
232 BatchResult::new(outputs)
233 }
234
235 pub fn adaptive_chunk_size(&self, results: &[StreamResult<impl Clone>]) -> usize {
237 if results.is_empty() {
238 return 32; }
240
241 let avg_throughput: f64 = results
243 .iter()
244 .map(|r| r.throughput_items_per_sec())
245 .sum::<f64>()
246 / results.len() as f64;
247
248 let target_time_ms = 100.0;
251 let items_per_chunk = (avg_throughput * target_time_ms / 1000.0) as usize;
252
253 items_per_chunk.clamp(1, 1000) }
255
256 pub fn config(&self) -> &StreamingConfig {
257 &self.config
258 }
259}
260
261impl Default for StreamProcessor {
262 fn default() -> Self {
263 Self::new(StreamingConfig::default())
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_config() {
        let cfg = StreamingConfig::new(StreamingMode::FixedChunk(64))
            .with_prefetch(2)
            .with_checkpointing(100);

        assert_eq!(cfg.mode, StreamingMode::FixedChunk(64));
        assert_eq!(cfg.prefetch_chunks, 2);
        assert_eq!(cfg.checkpoint_interval, Some(100));
    }

    #[test]
    fn test_chunk_metadata() {
        let meta = ChunkMetadata::new(0, 0, 32, 100);
        assert_eq!(meta.chunk_id, 0);
        assert_eq!(meta.size, 32);
        assert!(!meta.is_last);

        let tail = ChunkMetadata::new(3, 96, 100, 100);
        assert!(tail.is_last);
    }

    #[test]
    fn test_stream_result() {
        let meta = ChunkMetadata::new(0, 0, 32, 100);
        let result: StreamResult<i32> = StreamResult::new(vec![1, 2, 3], meta, 100.0);

        assert_eq!(result.outputs.len(), 3);
        assert!(result.throughput_items_per_sec() > 0.0);
    }

    #[test]
    fn test_chunk_iterator() {
        let iter = ChunkIterator::new(100, 32);
        assert_eq!(iter.num_chunks(), 4);

        let chunks: Vec<_> = iter.collect();
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].size, 32);
        assert_eq!(chunks[3].size, 4);
        assert!(chunks[3].is_last);
    }

    #[test]
    fn test_chunk_iterator_from_config() {
        let cfg = StreamingConfig::new(StreamingMode::FixedChunk(25));
        let iter = ChunkIterator::from_config(100, &cfg);

        assert_eq!(iter.chunk_size, 25);
        assert_eq!(iter.num_chunks(), 4);
    }

    #[test]
    fn test_stream_processor_split() {
        let batch = BatchResult::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let processor = StreamProcessor::new(StreamingConfig::new(StreamingMode::FixedChunk(3)));

        let chunks = processor.split_batch(&batch);
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].1, vec![1, 2, 3]);
        assert_eq!(chunks[1].1, vec![4, 5, 6]);
        assert_eq!(chunks[2].1, vec![7, 8, 9]);
        assert_eq!(chunks[3].1, vec![10]);
    }

    #[test]
    fn test_stream_processor_merge() {
        let results = vec![
            StreamResult::new(vec![1, 2, 3], ChunkMetadata::new(0, 0, 3, 10), 10.0),
            StreamResult::new(vec![4, 5, 6], ChunkMetadata::new(1, 3, 6, 10), 10.0),
            StreamResult::new(vec![7, 8, 9, 10], ChunkMetadata::new(2, 6, 10, 10), 10.0),
        ];

        let batch = StreamProcessor::merge_results(results);
        assert_eq!(batch.len(), 10);
        assert_eq!(batch.outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_adaptive_chunk_size() {
        let processor = StreamProcessor::default();
        let meta = ChunkMetadata::new(0, 0, 100, 1000);
        let results: Vec<_> = [50.0, 100.0, 75.0]
            .into_iter()
            .map(|ms| StreamResult::new(vec![(); 100], meta.clone(), ms))
            .collect();

        let chunk_size = processor.adaptive_chunk_size(&results);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 1000);
    }

    #[test]
    fn test_streaming_modes() {
        assert_eq!(StreamingMode::None, StreamingConfig::default().mode);

        let fixed = StreamingMode::FixedChunk(64);
        assert_eq!(fixed, StreamingMode::FixedChunk(64));

        let dynamic = StreamingMode::DynamicChunk {
            target_memory_mb: 512,
        };
        match dynamic {
            StreamingMode::DynamicChunk { target_memory_mb } => {
                assert_eq!(target_memory_mb, 512);
            }
            _ => panic!("Wrong mode"),
        }
    }
}