1use super::pools::{MemoryPool, MemoryPoolStats};
7use crate::{Result, TensorError};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
11pub struct MultiStreamMemoryManager {
13 pools: Vec<MemoryPool>,
14 stream_assignment: Arc<Mutex<HashMap<usize, usize>>>, current_stream: Arc<Mutex<usize>>,
16}
17
18impl MultiStreamMemoryManager {
19 #[cfg(feature = "gpu")]
21 pub fn new(device_id: usize, num_streams: usize, pool_size_per_stream: usize) -> Result<Self> {
22 let mut pools = Vec::new();
23
24 for _ in 0..num_streams {
25 pools.push(MemoryPool::new(device_id, pool_size_per_stream)?);
26 }
27
28 Ok(Self {
29 pools,
30 stream_assignment: Arc::new(Mutex::new(HashMap::new())),
31 current_stream: Arc::new(Mutex::new(0)),
32 })
33 }
34
35 pub fn get_pool(&self, operation_id: usize) -> Result<&MemoryPool> {
37 let stream_assignment = self
38 .stream_assignment
39 .lock()
40 .expect("lock should not be poisoned");
41
42 let stream_id = if let Some(&stream_id) = stream_assignment.get(&operation_id) {
43 stream_id
44 } else {
45 let mut current_stream = self
47 .current_stream
48 .lock()
49 .expect("lock should not be poisoned");
50 let stream_id = *current_stream;
51 *current_stream = (*current_stream + 1) % self.pools.len();
52 stream_id
53 };
54
55 self.pools
56 .get(stream_id)
57 .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
58 }
59
60 pub fn assign_operation_to_stream(&self, operation_id: usize, stream_id: usize) -> Result<()> {
62 if stream_id >= self.pools.len() {
63 return Err(TensorError::invalid_argument(format!(
64 "Stream ID {} out of range. Available streams: {}",
65 stream_id,
66 self.pools.len()
67 )));
68 }
69
70 let mut stream_assignment = self
71 .stream_assignment
72 .lock()
73 .expect("lock should not be poisoned");
74 stream_assignment.insert(operation_id, stream_id);
75 Ok(())
76 }
77
78 pub fn unassign_operation(&self, operation_id: usize) {
80 let mut stream_assignment = self
81 .stream_assignment
82 .lock()
83 .expect("lock should not be poisoned");
84 stream_assignment.remove(&operation_id);
85 }
86
87 pub fn get_operation_stream(&self, operation_id: usize) -> Option<usize> {
89 let stream_assignment = self
90 .stream_assignment
91 .lock()
92 .expect("lock should not be poisoned");
93 stream_assignment.get(&operation_id).copied()
94 }
95
96 pub fn num_streams(&self) -> usize {
98 self.pools.len()
99 }
100
101 pub fn get_pool_by_stream(&self, stream_id: usize) -> Result<&MemoryPool> {
103 self.pools
104 .get(stream_id)
105 .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
106 }
107
108 pub fn stats(&self) -> Vec<MemoryPoolStats> {
110 self.pools.iter().map(|pool| pool.stats()).collect()
111 }
112
113 pub fn stream_stats(&self, stream_id: usize) -> Result<MemoryPoolStats> {
115 self.pools
116 .get(stream_id)
117 .map(|pool| pool.stats())
118 .ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
119 }
120
121 pub fn total_memory_usage(&self) -> (usize, usize) {
123 let mut total_allocated = 0;
124 let mut total_free = 0;
125
126 for pool in &self.pools {
127 let stats = pool.stats();
128 total_allocated += stats.total_allocated;
129 total_free += stats.total_free;
130 }
131
132 (total_allocated, total_free)
133 }
134
135 pub fn get_least_loaded_stream(&self) -> usize {
137 let mut min_load = usize::MAX;
138 let mut best_stream = 0;
139
140 for (i, pool) in self.pools.iter().enumerate() {
141 let stats = pool.stats();
142 if stats.total_allocated < min_load {
143 min_load = stats.total_allocated;
144 best_stream = i;
145 }
146 }
147
148 best_stream
149 }
150
151 pub fn get_stream_with_most_free_memory(&self) -> usize {
153 let mut max_free = 0;
154 let mut best_stream = 0;
155
156 for (i, pool) in self.pools.iter().enumerate() {
157 let stats = pool.stats();
158 if stats.total_free > max_free {
159 max_free = stats.total_free;
160 best_stream = i;
161 }
162 }
163
164 best_stream
165 }
166
167 pub fn balance_streams(&self) -> Result<usize> {
169 let mut reassignments = 0;
170 let target_load = {
171 let (total_allocated, _) = self.total_memory_usage();
172 total_allocated / self.pools.len()
173 };
174
175 let mut stream_assignment = self
176 .stream_assignment
177 .lock()
178 .expect("lock should not be poisoned");
179
180 let mut overloaded_streams = Vec::new();
182 let mut underloaded_streams = Vec::new();
183
184 for (i, pool) in self.pools.iter().enumerate() {
185 let stats = pool.stats();
186 if stats.total_allocated > target_load * 11 / 10 {
187 overloaded_streams.push(i);
189 } else if stats.total_allocated < target_load * 9 / 10 {
190 underloaded_streams.push(i);
191 }
192 }
193
194 let operations_to_reassign: Vec<_> = stream_assignment
196 .iter()
197 .filter(|(_, &stream_id)| overloaded_streams.contains(&stream_id))
198 .map(|(&op_id, &stream_id)| (op_id, stream_id))
199 .collect();
200
201 for (op_id, _old_stream) in operations_to_reassign {
202 if let Some(&new_stream) = underloaded_streams.first() {
203 stream_assignment.insert(op_id, new_stream);
204 reassignments += 1;
205
206 underloaded_streams.rotate_left(1);
208 }
209 }
210
211 Ok(reassignments)
212 }
213
214 pub fn generate_streams_report(&self) -> String {
216 let mut report = String::new();
217 report.push_str("=== Multi-Stream Memory Manager Report ===\n\n");
218
219 let (total_allocated, total_free) = self.total_memory_usage();
220 report.push_str(&format!(
221 "Total Memory - Allocated: {} bytes, Free: {} bytes\n",
222 total_allocated, total_free
223 ));
224 report.push_str(&format!("Number of Streams: {}\n\n", self.pools.len()));
225
226 for (i, pool) in self.pools.iter().enumerate() {
228 let stats = pool.stats();
229 report.push_str(&format!("Stream {}:\n", i));
230 report.push_str(&format!(" Allocated: {} bytes\n", stats.total_allocated));
231 report.push_str(&format!(" Free: {} bytes\n", stats.total_free));
232 report.push_str(&format!(" Blocks Allocated: {}\n", stats.blocks_allocated));
233 report.push_str(&format!(" Blocks Free: {}\n", stats.blocks_free));
234 report.push_str(&format!(
235 " Fragmentation Ratio: {:.2}\n",
236 stats.fragmentation_ratio
237 ));
238 report.push_str(&format!(
239 " Memory Pressure: {:.2}%\n",
240 stats.memory_pressure * 100.0
241 ));
242 report.push('\n');
243 }
244
245 let stream_assignment = self
247 .stream_assignment
248 .lock()
249 .expect("lock should not be poisoned");
250 if !stream_assignment.is_empty() {
251 report.push_str("Operation Assignments:\n");
252 for (op_id, stream_id) in stream_assignment.iter() {
253 report.push_str(&format!(" Operation {}: Stream {}\n", op_id, stream_id));
254 }
255 }
256
257 report
258 }
259
260 pub fn clear_assignments(&self) {
262 let mut stream_assignment = self
263 .stream_assignment
264 .lock()
265 .expect("lock should not be poisoned");
266 stream_assignment.clear();
267 }
268
269 pub fn get_operation_counts(&self) -> Vec<usize> {
271 let stream_assignment = self
272 .stream_assignment
273 .lock()
274 .expect("lock should not be poisoned");
275 let mut counts = vec![0; self.pools.len()];
276
277 for &stream_id in stream_assignment.values() {
278 if stream_id < counts.len() {
279 counts[stream_id] += 1;
280 }
281 }
282
283 counts
284 }
285
286 pub fn are_streams_balanced(&self, tolerance_percent: f32) -> bool {
288 let (total_allocated, _) = self.total_memory_usage();
289 if total_allocated == 0 {
290 return true; }
292
293 let target_load = total_allocated / self.pools.len();
294 let tolerance = (target_load as f32 * tolerance_percent / 100.0) as usize;
295
296 for pool in &self.pools {
297 let stats = pool.stats();
298 let deviation = stats.total_allocated.abs_diff(target_load);
299
300 if deviation > tolerance {
301 return false;
302 }
303 }
304
305 true
306 }
307}
308
309#[cfg(test)]
310mod tests {
311 use super::*;
312
313 #[test]
317 fn test_stream_assignment() {
318 let _assignments: HashMap<usize, usize> = HashMap::new();
322 let current_stream = 0;
323
324 let num_streams = 3;
326 let mut stream_id = current_stream;
327
328 for i in 0..6 {
329 let assigned_stream = stream_id;
331 stream_id = (stream_id + 1) % num_streams;
332
333 assert_eq!(assigned_stream, i % num_streams);
334 }
335 }
336
337 #[test]
338 fn test_load_balancing_logic() {
339 let target_load = 1000;
341 let tolerance = target_load / 10; let overloaded = 1200;
345 assert!(overloaded > target_load + tolerance);
346
347 let underloaded = 800;
349 assert!(underloaded < target_load - tolerance);
350
351 let balanced = 950;
353 assert!(balanced >= target_load - tolerance && balanced <= target_load + tolerance);
354 }
355
356 #[test]
357 fn test_stream_balancing_calculation() {
358 let total_allocated = 3000;
360 let num_streams = 3;
361 let target_load = total_allocated / num_streams; assert_eq!(target_load, 1000);
364
365 let tolerance_percent = 10.0;
367 let tolerance = (target_load as f32 * tolerance_percent / 100.0) as usize;
368 assert_eq!(tolerance, 100);
369
370 let stream_load: usize = 1150;
372 let deviation = stream_load.abs_diff(target_load);
373 assert_eq!(deviation, 150);
374 assert!(deviation > tolerance); }
376
377 #[test]
378 fn test_operation_count_tracking() {
379 let mut counts = vec![0; 3]; let assignments = vec![(1, 0), (2, 1), (3, 0), (4, 2), (5, 1)];
381
382 for (_, stream_id) in assignments {
383 if stream_id < counts.len() {
384 counts[stream_id] += 1;
385 }
386 }
387
388 assert_eq!(counts, vec![2, 2, 1]); }
390
391 #[test]
392 fn test_memory_usage_aggregation() {
393 let stream_stats = vec![
395 (500, 1500), (800, 1200),
397 (300, 1700),
398 ];
399
400 let mut total_allocated = 0;
401 let mut total_free = 0;
402
403 for (allocated, free) in stream_stats {
404 total_allocated += allocated;
405 total_free += free;
406 }
407
408 assert_eq!(total_allocated, 1600);
409 assert_eq!(total_free, 4400);
410 }
411
412 #[test]
413 fn test_least_loaded_stream_selection() {
414 let stream_loads = [1200, 800, 1000];
416
417 let mut min_load = usize::MAX;
418 let mut best_stream = 0;
419
420 for (i, &load) in stream_loads.iter().enumerate() {
421 if load < min_load {
422 min_load = load;
423 best_stream = i;
424 }
425 }
426
427 assert_eq!(best_stream, 1); assert_eq!(min_load, 800);
429 }
430
431 #[test]
432 fn test_most_free_memory_selection() {
433 let stream_free_memory = [500, 1200, 800];
435
436 let mut max_free = 0;
437 let mut best_stream = 0;
438
439 for (i, &free) in stream_free_memory.iter().enumerate() {
440 if free > max_free {
441 max_free = free;
442 best_stream = i;
443 }
444 }
445
446 assert_eq!(best_stream, 1); assert_eq!(max_free, 1200);
448 }
449}