// tenflowers_core/memory/streams.rs

1//! Multi-stream memory management for concurrent operations
2//!
3//! This module provides concurrent memory management across multiple streams,
4//! enabling efficient parallel GPU operations.
5
6use super::pools::{MemoryPool, MemoryPoolStats};
7use crate::{Result, TensorError};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
/// Multi-stream memory management for concurrent operations
pub struct MultiStreamMemoryManager {
    /// One memory pool per stream; the vector index doubles as the stream ID.
    pools: Vec<MemoryPool>,
    stream_assignment: Arc<Mutex<HashMap<usize, usize>>>, // operation_id -> stream_id
    /// Next stream to hand out when an unassigned operation requests a pool
    /// (round-robin cursor, always < pools.len()).
    current_stream: Arc<Mutex<usize>>,
}
17
18impl MultiStreamMemoryManager {
19    /// Create a new multi-stream memory manager
20    #[cfg(feature = "gpu")]
21    pub fn new(device_id: usize, num_streams: usize, pool_size_per_stream: usize) -> Result<Self> {
22        let mut pools = Vec::new();
23
24        for _ in 0..num_streams {
25            pools.push(MemoryPool::new(device_id, pool_size_per_stream)?);
26        }
27
28        Ok(Self {
29            pools,
30            stream_assignment: Arc::new(Mutex::new(HashMap::new())),
31            current_stream: Arc::new(Mutex::new(0)),
32        })
33    }
34
35    /// Get the appropriate memory pool for an operation
36    pub fn get_pool(&self, operation_id: usize) -> Result<&MemoryPool> {
37        let stream_assignment = self
38            .stream_assignment
39            .lock()
40            .expect("lock should not be poisoned");
41
42        let stream_id = if let Some(&stream_id) = stream_assignment.get(&operation_id) {
43            stream_id
44        } else {
45            // Assign to current stream and rotate
46            let mut current_stream = self
47                .current_stream
48                .lock()
49                .expect("lock should not be poisoned");
50            let stream_id = *current_stream;
51            *current_stream = (*current_stream + 1) % self.pools.len();
52            stream_id
53        };
54
55        self.pools
56            .get(stream_id)
57            .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
58    }
59
60    /// Assign a specific operation to a specific stream
61    pub fn assign_operation_to_stream(&self, operation_id: usize, stream_id: usize) -> Result<()> {
62        if stream_id >= self.pools.len() {
63            return Err(TensorError::invalid_argument(format!(
64                "Stream ID {} out of range. Available streams: {}",
65                stream_id,
66                self.pools.len()
67            )));
68        }
69
70        let mut stream_assignment = self
71            .stream_assignment
72            .lock()
73            .expect("lock should not be poisoned");
74        stream_assignment.insert(operation_id, stream_id);
75        Ok(())
76    }
77
78    /// Remove an operation's stream assignment
79    pub fn unassign_operation(&self, operation_id: usize) {
80        let mut stream_assignment = self
81            .stream_assignment
82            .lock()
83            .expect("lock should not be poisoned");
84        stream_assignment.remove(&operation_id);
85    }
86
87    /// Get the stream ID for a specific operation
88    pub fn get_operation_stream(&self, operation_id: usize) -> Option<usize> {
89        let stream_assignment = self
90            .stream_assignment
91            .lock()
92            .expect("lock should not be poisoned");
93        stream_assignment.get(&operation_id).copied()
94    }
95
96    /// Get the number of available streams
97    pub fn num_streams(&self) -> usize {
98        self.pools.len()
99    }
100
101    /// Get a specific pool by stream ID
102    pub fn get_pool_by_stream(&self, stream_id: usize) -> Result<&MemoryPool> {
103        self.pools
104            .get(stream_id)
105            .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
106    }
107
108    /// Get statistics for all streams
109    pub fn stats(&self) -> Vec<MemoryPoolStats> {
110        self.pools.iter().map(|pool| pool.stats()).collect()
111    }
112
113    /// Get statistics for a specific stream
114    pub fn stream_stats(&self, stream_id: usize) -> Result<MemoryPoolStats> {
115        self.pools
116            .get(stream_id)
117            .map(|pool| pool.stats())
118            .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
119    }
120
121    /// Get total memory usage across all streams
122    pub fn total_memory_usage(&self) -> (usize, usize) {
123        let mut total_allocated = 0;
124        let mut total_free = 0;
125
126        for pool in &self.pools {
127            let stats = pool.stats();
128            total_allocated += stats.total_allocated;
129            total_free += stats.total_free;
130        }
131
132        (total_allocated, total_free)
133    }
134
135    /// Get the least loaded stream (for load balancing)
136    pub fn get_least_loaded_stream(&self) -> usize {
137        let mut min_load = usize::MAX;
138        let mut best_stream = 0;
139
140        for (i, pool) in self.pools.iter().enumerate() {
141            let stats = pool.stats();
142            if stats.total_allocated < min_load {
143                min_load = stats.total_allocated;
144                best_stream = i;
145            }
146        }
147
148        best_stream
149    }
150
151    /// Get the stream with the most free memory
152    pub fn get_stream_with_most_free_memory(&self) -> usize {
153        let mut max_free = 0;
154        let mut best_stream = 0;
155
156        for (i, pool) in self.pools.iter().enumerate() {
157            let stats = pool.stats();
158            if stats.total_free > max_free {
159                max_free = stats.total_free;
160                best_stream = i;
161            }
162        }
163
164        best_stream
165    }
166
167    /// Balance memory across streams by reassigning operations
168    pub fn balance_streams(&self) -> Result<usize> {
169        let mut reassignments = 0;
170        let target_load = {
171            let (total_allocated, _) = self.total_memory_usage();
172            total_allocated / self.pools.len()
173        };
174
175        let mut stream_assignment = self
176            .stream_assignment
177            .lock()
178            .expect("lock should not be poisoned");
179
180        // Identify overloaded and underloaded streams
181        let mut overloaded_streams = Vec::new();
182        let mut underloaded_streams = Vec::new();
183
184        for (i, pool) in self.pools.iter().enumerate() {
185            let stats = pool.stats();
186            if stats.total_allocated > target_load * 11 / 10 {
187                // 10% tolerance
188                overloaded_streams.push(i);
189            } else if stats.total_allocated < target_load * 9 / 10 {
190                underloaded_streams.push(i);
191            }
192        }
193
194        // Reassign operations from overloaded to underloaded streams
195        let operations_to_reassign: Vec<_> = stream_assignment
196            .iter()
197            .filter(|(_, &stream_id)| overloaded_streams.contains(&stream_id))
198            .map(|(&op_id, &stream_id)| (op_id, stream_id))
199            .collect();
200
201        for (op_id, _old_stream) in operations_to_reassign {
202            if let Some(&new_stream) = underloaded_streams.first() {
203                stream_assignment.insert(op_id, new_stream);
204                reassignments += 1;
205
206                // Rotate underloaded streams for fair distribution
207                underloaded_streams.rotate_left(1);
208            }
209        }
210
211        Ok(reassignments)
212    }
213
214    /// Generate a comprehensive report of all streams
215    pub fn generate_streams_report(&self) -> String {
216        let mut report = String::new();
217        report.push_str("=== Multi-Stream Memory Manager Report ===\n\n");
218
219        let (total_allocated, total_free) = self.total_memory_usage();
220        report.push_str(&format!(
221            "Total Memory - Allocated: {} bytes, Free: {} bytes\n",
222            total_allocated, total_free
223        ));
224        report.push_str(&format!("Number of Streams: {}\n\n", self.pools.len()));
225
226        // Per-stream statistics
227        for (i, pool) in self.pools.iter().enumerate() {
228            let stats = pool.stats();
229            report.push_str(&format!("Stream {}:\n", i));
230            report.push_str(&format!("  Allocated: {} bytes\n", stats.total_allocated));
231            report.push_str(&format!("  Free: {} bytes\n", stats.total_free));
232            report.push_str(&format!("  Blocks Allocated: {}\n", stats.blocks_allocated));
233            report.push_str(&format!("  Blocks Free: {}\n", stats.blocks_free));
234            report.push_str(&format!(
235                "  Fragmentation Ratio: {:.2}\n",
236                stats.fragmentation_ratio
237            ));
238            report.push_str(&format!(
239                "  Memory Pressure: {:.2}%\n",
240                stats.memory_pressure * 100.0
241            ));
242            report.push('\n');
243        }
244
245        // Operation assignments
246        let stream_assignment = self
247            .stream_assignment
248            .lock()
249            .expect("lock should not be poisoned");
250        if !stream_assignment.is_empty() {
251            report.push_str("Operation Assignments:\n");
252            for (op_id, stream_id) in stream_assignment.iter() {
253                report.push_str(&format!("  Operation {}: Stream {}\n", op_id, stream_id));
254            }
255        }
256
257        report
258    }
259
260    /// Clear all operation assignments
261    pub fn clear_assignments(&self) {
262        let mut stream_assignment = self
263            .stream_assignment
264            .lock()
265            .expect("lock should not be poisoned");
266        stream_assignment.clear();
267    }
268
269    /// Get operation count per stream
270    pub fn get_operation_counts(&self) -> Vec<usize> {
271        let stream_assignment = self
272            .stream_assignment
273            .lock()
274            .expect("lock should not be poisoned");
275        let mut counts = vec![0; self.pools.len()];
276
277        for &stream_id in stream_assignment.values() {
278            if stream_id < counts.len() {
279                counts[stream_id] += 1;
280            }
281        }
282
283        counts
284    }
285
286    /// Check if streams are balanced (within tolerance)
287    pub fn are_streams_balanced(&self, tolerance_percent: f32) -> bool {
288        let (total_allocated, _) = self.total_memory_usage();
289        if total_allocated == 0 {
290            return true; // No memory allocated, considered balanced
291        }
292
293        let target_load = total_allocated / self.pools.len();
294        let tolerance = (target_load as f32 * tolerance_percent / 100.0) as usize;
295
296        for pool in &self.pools {
297            let stats = pool.stats();
298            let deviation = stats.total_allocated.abs_diff(target_load);
299
300            if deviation > tolerance {
301                return false;
302            }
303        }
304
305        true
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    // Note: exercising the manager end-to-end needs a real GPU context, so
    // these tests cover only the pure scheduling/accounting arithmetic.

    #[test]
    fn test_stream_assignment() {
        // Round-robin placement: the k-th new operation lands on stream k % n.
        let _assignments: HashMap<usize, usize> = HashMap::new();
        let num_streams = 3;
        let mut cursor = 0usize;

        for op in 0..6 {
            let placed_on = cursor;
            cursor = (cursor + 1) % num_streams;
            assert_eq!(placed_on, op % num_streams);
        }
    }

    #[test]
    fn test_load_balancing_logic() {
        // Classification thresholds around the mean load.
        let target = 1000;
        let tol = target / 10; // 10% tolerance

        // Above target + tolerance -> overloaded.
        assert!(1200 > target + tol);
        // Below target - tolerance -> underloaded.
        assert!(800 < target - tol);
        // Inside the band -> balanced.
        assert!((target - tol..=target + tol).contains(&950));
    }

    #[test]
    fn test_stream_balancing_calculation() {
        // Mean-load computation.
        let (total, streams) = (3000usize, 3usize);
        let target = total / streams;
        assert_eq!(target, 1000);

        // Percentage tolerance in bytes.
        let tolerance = (target as f32 * 10.0 / 100.0) as usize;
        assert_eq!(tolerance, 100);

        // A stream at 1150 deviates by 150 > 100 and is flagged unbalanced.
        let deviation = 1150usize.abs_diff(target);
        assert_eq!(deviation, 150);
        assert!(deviation > tolerance);
    }

    #[test]
    fn test_operation_count_tracking() {
        let assignments = [(1, 0), (2, 1), (3, 0), (4, 2), (5, 1)];
        let mut counts = vec![0usize; 3]; // 3 streams
        let num_streams = counts.len();

        for &(_op, stream) in assignments.iter() {
            if stream < num_streams {
                counts[stream] += 1;
            }
        }

        // stream 0: 2 ops, stream 1: 2 ops, stream 2: 1 op
        assert_eq!(counts, vec![2, 2, 1]);
    }

    #[test]
    fn test_memory_usage_aggregation() {
        // (allocated, free) pairs per stream.
        let per_stream = [(500, 1500), (800, 1200), (300, 1700)];

        let total_allocated: usize = per_stream.iter().map(|&(a, _)| a).sum();
        let total_free: usize = per_stream.iter().map(|&(_, f)| f).sum();

        assert_eq!(total_allocated, 1600);
        assert_eq!(total_free, 4400);
    }

    #[test]
    fn test_least_loaded_stream_selection() {
        let loads = [1200usize, 800, 1000];

        let (best_stream, &min_load) =
            loads.iter().enumerate().min_by_key(|&(_, &l)| l).unwrap();

        // Stream 1 carries the minimum load of 800.
        assert_eq!(best_stream, 1);
        assert_eq!(min_load, 800);
    }

    #[test]
    fn test_most_free_memory_selection() {
        let free = [500usize, 1200, 800];

        let (best_stream, &max_free) =
            free.iter().enumerate().max_by_key(|&(_, &f)| f).unwrap();

        // Stream 1 has the maximum free memory of 1200.
        assert_eq!(best_stream, 1);
        assert_eq!(max_free, 1200);
    }
}