1use std::collections::{HashMap, HashSet};
8
9use super::{EinsumGraph, OpType};
10use crate::error::IrError;
11
/// Memory footprint and liveness window of a single tensor in the graph.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorMemory {
    /// Index of this tensor in the graph's tensor list.
    pub tensor_idx: usize,
    /// Estimated size of the tensor in bytes (see `estimate_tensor_size`).
    pub size_bytes: usize,
    /// First node step that reads or writes this tensor; `None` if unused.
    pub first_use: Option<usize>,
    /// Last node step that reads or writes this tensor; `None` if unused.
    pub last_use: Option<usize>,
}
24
/// Aggregate memory statistics for an `EinsumGraph`, produced by `analyze_memory`.
#[derive(Debug, Clone)]
pub struct MemoryAnalysis {
    /// Per-tensor size and liveness information.
    pub tensors: Vec<TensorMemory>,
    /// Maximum total bytes live at any single execution step.
    pub peak_memory_bytes: usize,
    /// Sum of all tensor sizes, ignoring liveness overlap.
    pub total_memory_bytes: usize,
    /// Estimated average live memory as a fraction of peak (0.0 when peak is 0).
    pub avg_utilization: f64,
    /// Node execution order chosen by the memory-aware scheduler.
    pub optimal_schedule: Vec<usize>,
}
39
40impl MemoryAnalysis {
41 pub fn new() -> Self {
43 Self {
44 tensors: Vec::new(),
45 peak_memory_bytes: 0,
46 total_memory_bytes: 0,
47 avg_utilization: 0.0,
48 optimal_schedule: Vec::new(),
49 }
50 }
51
52 pub fn memory_waste_ratio(&self) -> f64 {
54 if self.peak_memory_bytes == 0 {
55 return 0.0;
56 }
57 let avg_memory = self.total_memory_bytes as f64 * self.avg_utilization;
58 (self.peak_memory_bytes as f64 - avg_memory) / self.peak_memory_bytes as f64
59 }
60}
61
62impl Default for MemoryAnalysis {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68pub fn analyze_memory(
88 graph: &EinsumGraph,
89 element_size_bytes: usize,
90) -> Result<MemoryAnalysis, IrError> {
91 if graph.nodes.is_empty() {
92 return Ok(MemoryAnalysis::new());
93 }
94
95 let tensor_lifetimes = analyze_tensor_lifetimes(graph);
97
98 let mut tensor_memories = Vec::new();
100 for (tensor_idx, (first_use, last_use)) in tensor_lifetimes.iter().enumerate() {
101 let size_bytes = estimate_tensor_size(graph, tensor_idx, element_size_bytes);
103 tensor_memories.push(TensorMemory {
104 tensor_idx,
105 size_bytes,
106 first_use: *first_use,
107 last_use: *last_use,
108 });
109 }
110
111 let peak_memory_bytes = compute_peak_memory(graph, &tensor_memories);
113
114 let total_memory_bytes = tensor_memories.iter().map(|t| t.size_bytes).sum();
116
117 let avg_utilization = if graph.nodes.is_empty() {
119 0.0
120 } else {
121 let total_live: usize = (0..graph.nodes.len())
123 .map(|step| count_live_tensors_at_step(step, &tensor_memories))
124 .sum();
125 let avg_live = total_live as f64 / graph.nodes.len() as f64;
126 let avg_memory = avg_live * (total_memory_bytes as f64 / tensor_memories.len() as f64);
127 if peak_memory_bytes > 0 {
128 avg_memory / peak_memory_bytes as f64
129 } else {
130 0.0
131 }
132 };
133
134 let optimal_schedule = generate_memory_optimal_schedule(graph, &tensor_memories)?;
136
137 Ok(MemoryAnalysis {
138 tensors: tensor_memories,
139 peak_memory_bytes,
140 total_memory_bytes,
141 avg_utilization,
142 optimal_schedule,
143 })
144}
145
146fn analyze_tensor_lifetimes(graph: &EinsumGraph) -> Vec<(Option<usize>, Option<usize>)> {
148 let mut lifetimes = vec![(None, None); graph.tensors.len()];
149
150 for (node_idx, node) in graph.nodes.iter().enumerate() {
151 for &input_idx in &node.inputs {
153 if input_idx < lifetimes.len() {
154 let (ref mut first, ref mut last) = lifetimes[input_idx];
155 *first = Some(first.map_or(node_idx, |f: usize| f.min(node_idx)));
156 *last = Some(last.map_or(node_idx, |l: usize| l.max(node_idx)));
157 }
158 }
159
160 for &output_idx in &node.outputs {
162 if output_idx < lifetimes.len() {
163 let (ref mut first, ref mut last) = lifetimes[output_idx];
164 *first = Some(first.map_or(node_idx, |f: usize| f.min(node_idx)));
165 *last = Some(last.map_or(node_idx, |l: usize| l.max(node_idx)));
166 }
167 }
168 }
169
170 lifetimes
171}
172
173fn estimate_tensor_size(
175 _graph: &EinsumGraph,
176 _tensor_idx: usize,
177 element_size_bytes: usize,
178) -> usize {
179 1000 * element_size_bytes
182}
183
184fn compute_peak_memory(graph: &EinsumGraph, tensors: &[TensorMemory]) -> usize {
186 let mut peak = 0;
187
188 for step in 0..graph.nodes.len() {
189 let live_memory: usize = tensors
190 .iter()
191 .filter(|t| is_tensor_live_at_step(t, step))
192 .map(|t| t.size_bytes)
193 .sum();
194 peak = peak.max(live_memory);
195 }
196
197 peak
198}
199
200fn is_tensor_live_at_step(tensor: &TensorMemory, step: usize) -> bool {
202 match (tensor.first_use, tensor.last_use) {
203 (Some(first), Some(last)) => step >= first && step <= last,
204 _ => false,
205 }
206}
207
208fn count_live_tensors_at_step(step: usize, tensors: &[TensorMemory]) -> usize {
210 tensors
211 .iter()
212 .filter(|t| is_tensor_live_at_step(t, step))
213 .count()
214}
215
216fn generate_memory_optimal_schedule(
222 graph: &EinsumGraph,
223 _tensors: &[TensorMemory],
224) -> Result<Vec<usize>, IrError> {
225 let dependencies = build_dependencies(graph);
227
228 let schedule = topological_sort_memory_aware(graph, &dependencies);
230
231 Ok(schedule)
232}
233
234fn build_dependencies(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
236 let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
237 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
238
239 for (node_idx, node) in graph.nodes.iter().enumerate() {
241 for &output_idx in &node.outputs {
242 tensor_producer.insert(output_idx, node_idx);
243 }
244 }
245
246 for (node_idx, node) in graph.nodes.iter().enumerate() {
248 let mut deps = Vec::new();
249 for &input_idx in &node.inputs {
250 if let Some(&producer) = tensor_producer.get(&input_idx) {
251 if producer != node_idx {
252 deps.push(producer);
253 }
254 }
255 }
256 dependencies.insert(node_idx, deps);
257 }
258
259 dependencies
260}
261
262fn topological_sort_memory_aware(
264 graph: &EinsumGraph,
265 dependencies: &HashMap<usize, Vec<usize>>,
266) -> Vec<usize> {
267 let mut schedule = Vec::new();
268 let mut scheduled = HashSet::new();
269 let mut in_degree = vec![0; graph.nodes.len()];
270
271 for deps in dependencies.values() {
273 for &dep in deps {
274 if dep < in_degree.len() {
275 in_degree[dep] += 1;
276 }
277 }
278 }
279
280 while schedule.len() < graph.nodes.len() {
282 let ready: Vec<usize> = (0..graph.nodes.len())
284 .filter(|&i| !scheduled.contains(&i) && in_degree[i] == 0)
285 .collect();
286
287 if ready.is_empty() {
288 break; }
290
291 let next = select_next_node_memory_aware(graph, &ready);
293 schedule.push(next);
294 scheduled.insert(next);
295
296 if let Some(deps) = dependencies.get(&next) {
298 for &dep in deps {
299 if dep < in_degree.len() {
300 let current_degree: usize = in_degree[dep];
301 in_degree[dep] = current_degree.saturating_sub(1);
302 }
303 }
304 }
305 }
306
307 schedule
308}
309
310fn select_next_node_memory_aware(graph: &EinsumGraph, candidates: &[usize]) -> usize {
312 candidates
314 .iter()
315 .min_by_key(|&&idx| {
316 graph
317 .nodes
318 .get(idx)
319 .map(|n| n.outputs.len())
320 .unwrap_or(usize::MAX)
321 })
322 .copied()
323 .unwrap_or(0)
324}
325
326pub fn analyze_inplace_opportunities(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
331 let mut inplace_candidates = Vec::new();
332
333 for (node_idx, node) in graph.nodes.iter().enumerate() {
334 if can_be_inplace(&node.op) && has_single_input_use(graph, node_idx) {
335 inplace_candidates.push(node_idx);
336 }
337 }
338
339 Ok(inplace_candidates)
340}
341
/// True when this op kind may write its result directly into its input
/// buffer. Currently only element-wise unary ops qualify.
fn can_be_inplace(op_type: &OpType) -> bool {
    matches!(op_type, OpType::ElemUnary { .. })
}
347
348fn has_single_input_use(graph: &EinsumGraph, node_idx: usize) -> bool {
350 let node = &graph.nodes[node_idx];
351 if node.inputs.is_empty() {
352 return false;
353 }
354
355 let input_tensor = node.inputs[0];
356
357 let use_count = graph
359 .nodes
360 .iter()
361 .filter(|n| n.inputs.contains(&input_tensor))
362 .count();
363
364 use_count == 1
365}
366
// Unit tests for memory analysis, lifetimes, scheduling, and in-place detection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // Default analysis is fully zeroed.
    #[test]
    fn test_memory_analysis_default() {
        let analysis = MemoryAnalysis::default();
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.total_memory_bytes, 0);
    }

    // An empty graph short-circuits to an empty analysis.
    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_memory(&graph, 8).unwrap();
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.tensors.len(), 0);
    }

    // One unary node touching two tensors yields nonzero peak memory.
    #[test]
    fn test_analyze_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let analysis = analyze_memory(&graph, 8).unwrap();
        assert!(analysis.peak_memory_bytes > 0);
        assert_eq!(analysis.tensors.len(), 2);
    }

    // Input and output of a single node are both live only at step 0.
    #[test]
    fn test_tensor_lifetime_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[a], (Some(0), Some(0)));
        assert_eq!(lifetimes[b], (Some(0), Some(0)));
    }

    // A tensor produced at step 0 and consumed at step 1 spans both steps.
    #[test]
    fn test_tensor_lifetime_multiple_uses() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[b], (Some(0), Some(1)));
    }

    // Placeholder estimate: 1000 elements times the element size.
    #[test]
    fn test_estimate_tensor_size() {
        let graph = EinsumGraph::new();
        let size = estimate_tensor_size(&graph, 0, 8);
        assert_eq!(size, 8000); }

    // Liveness window [2, 5] is inclusive on both ends.
    #[test]
    fn test_is_tensor_live_at_step() {
        let tensor = TensorMemory {
            tensor_idx: 0,
            size_bytes: 1000,
            first_use: Some(2),
            last_use: Some(5),
        };

        assert!(!is_tensor_live_at_step(&tensor, 0));
        assert!(!is_tensor_live_at_step(&tensor, 1));
        assert!(is_tensor_live_at_step(&tensor, 2));
        assert!(is_tensor_live_at_step(&tensor, 3));
        assert!(is_tensor_live_at_step(&tensor, 5));
        assert!(!is_tensor_live_at_step(&tensor, 6));
    }

    // A zero peak must not divide by zero; ratio is defined as 0.0.
    #[test]
    fn test_memory_waste_ratio_zero_peak() {
        let analysis = MemoryAnalysis {
            peak_memory_bytes: 0,
            total_memory_bytes: 1000,
            avg_utilization: 0.5,
            ..Default::default()
        };
        assert_eq!(analysis.memory_waste_ratio(), 0.0);
    }

    // Only element-wise unary ops are in-place candidates.
    #[test]
    fn test_can_be_inplace() {
        assert!(can_be_inplace(&OpType::ElemUnary {
            op: "relu".to_string()
        }));
        assert!(!can_be_inplace(&OpType::Einsum {
            spec: "ij,jk->ik".to_string()
        }));
    }

    // No nodes means no in-place opportunities.
    #[test]
    fn test_analyze_inplace_opportunities_empty() {
        let graph = EinsumGraph::new();
        let candidates = analyze_inplace_opportunities(&graph).unwrap();
        assert!(candidates.is_empty());
    }

    // A unary node whose input has a single consumer qualifies for in-place.
    #[test]
    fn test_analyze_inplace_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let candidates = analyze_inplace_opportunities(&graph).unwrap();
        assert_eq!(candidates.len(), 1);
    }

    // Node 1 consumes node 0's output, so it depends on node 0.
    #[test]
    fn test_build_dependencies() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let deps = build_dependencies(&graph);
        assert_eq!(deps.get(&0).unwrap().len(), 0); assert_eq!(deps.get(&1).unwrap(), &vec![0]); }

    // A single node sorts trivially to itself.
    #[test]
    fn test_topological_sort_simple() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let deps = build_dependencies(&graph);
        let schedule = topological_sort_memory_aware(&graph, &deps);
        assert_eq!(schedule, vec![0]);
    }
}