
tensorlogic_ir/graph/memory.rs

//! Memory optimization and layout analysis.
//!
//! This module provides utilities for analyzing and optimizing memory usage
//! in computation graphs. It includes memory footprint estimation, tensor
//! lifetime analysis, and operation reordering to minimize peak memory usage.
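//!
//! # Example
//!
//! A minimal end-to-end sketch, assuming `EinsumGraph`, `EinsumNode`, and
//! `analyze_memory` are re-exported at the crate root (the `analyze_memory`
//! example below makes the same assumption for `EinsumGraph`):
//!
//! ```rust
//! use tensorlogic_ir::{EinsumGraph, EinsumNode, analyze_memory};
//!
//! let mut graph = EinsumGraph::new();
//! let a = graph.add_tensor("A");
//! let b = graph.add_tensor("B");
//! graph.add_node(EinsumNode::elem_unary("relu", a, b)).unwrap();
//!
//! // Analyze with 8-byte (f64) elements.
//! let analysis = analyze_memory(&graph, 8).unwrap();
//! assert!(analysis.peak_memory_bytes > 0);
//! ```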

use std::collections::{HashMap, HashSet};

use super::{EinsumGraph, OpType};
use crate::error::IrError;

/// Memory footprint estimate for a tensor
#[derive(Debug, Clone, PartialEq)]
pub struct TensorMemory {
    /// Tensor index
    pub tensor_idx: usize,
    /// Estimated size in bytes
    pub size_bytes: usize,
    /// First node that uses this tensor
    pub first_use: Option<usize>,
    /// Last node that uses this tensor
    pub last_use: Option<usize>,
}

/// Memory optimization analysis result
#[derive(Debug, Clone)]
pub struct MemoryAnalysis {
    /// Memory information for each tensor
    pub tensors: Vec<TensorMemory>,
    /// Peak memory usage in bytes
    pub peak_memory_bytes: usize,
    /// Total memory allocated across all tensors
    pub total_memory_bytes: usize,
    /// Average memory utilization (0.0 to 1.0)
    pub avg_utilization: f64,
    /// Suggested operation execution order for minimal peak memory
    pub optimal_schedule: Vec<usize>,
}

impl MemoryAnalysis {
    /// Create new empty analysis
    pub fn new() -> Self {
        Self {
            tensors: Vec::new(),
            peak_memory_bytes: 0,
            total_memory_bytes: 0,
            avg_utilization: 0.0,
            optimal_schedule: Vec::new(),
        }
    }

    /// Get the memory waste ratio: (peak - average live memory) / peak, in [0.0, 1.0]
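    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative field values, assuming
    /// `MemoryAnalysis` is re-exported at the crate root like `EinsumGraph`
    /// in the `analyze_memory` example below:
    ///
    /// ```rust
    /// use tensorlogic_ir::MemoryAnalysis;
    ///
    /// let analysis = MemoryAnalysis {
    ///     peak_memory_bytes: 1000,
    ///     avg_utilization: 0.6,
    ///     ..Default::default()
    /// };
    /// // waste = (1000 - 1000 * 0.6) / 1000 = 0.4
    /// assert!((analysis.memory_waste_ratio() - 0.4).abs() < 1e-9);
    /// ```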
    pub fn memory_waste_ratio(&self) -> f64 {
        if self.peak_memory_bytes == 0 {
            return 0.0;
        }
        // `avg_utilization` is defined relative to peak memory in
        // `analyze_memory`, so average live memory is peak * utilization.
        let avg_memory = self.peak_memory_bytes as f64 * self.avg_utilization;
        (self.peak_memory_bytes as f64 - avg_memory) / self.peak_memory_bytes as f64
    }
}

impl Default for MemoryAnalysis {
    fn default() -> Self {
        Self::new()
    }
}

/// Analyze memory usage patterns in a computation graph
///
/// This function performs a comprehensive analysis of memory usage including:
/// - Tensor lifetime analysis (first use to last use)
/// - Peak memory estimation
/// - Memory utilization statistics
///
/// # Example
///
/// ```rust
/// use tensorlogic_ir::{EinsumGraph, analyze_memory};
///
/// let mut graph = EinsumGraph::new();
/// // Build your graph...
///
/// let analysis = analyze_memory(&graph, 8).unwrap();
/// println!("Peak memory: {} bytes", analysis.peak_memory_bytes);
/// println!("Memory waste ratio: {:.2}%", analysis.memory_waste_ratio() * 100.0);
/// ```
pub fn analyze_memory(
    graph: &EinsumGraph,
    element_size_bytes: usize,
) -> Result<MemoryAnalysis, IrError> {
    if graph.nodes.is_empty() {
        return Ok(MemoryAnalysis::new());
    }

    // Analyze tensor lifetimes
    let tensor_lifetimes = analyze_tensor_lifetimes(graph);

    // Estimate tensor sizes (simplified: assume all tensors are the same size)
    let mut tensor_memories = Vec::new();
    for (tensor_idx, (first_use, last_use)) in tensor_lifetimes.iter().enumerate() {
        // Simplified size estimation
        let size_bytes = estimate_tensor_size(graph, tensor_idx, element_size_bytes);
        tensor_memories.push(TensorMemory {
            tensor_idx,
            size_bytes,
            first_use: *first_use,
            last_use: *last_use,
        });
    }

    // Compute peak memory usage
    let peak_memory_bytes = compute_peak_memory(graph, &tensor_memories);

    // Compute total memory
    let total_memory_bytes = tensor_memories.iter().map(|t| t.size_bytes).sum();

    // Estimate average utilization: average live memory per step, relative
    // to peak (the graph is non-empty here thanks to the early return above)
    let avg_utilization = {
        // Average number of live tensors at each step
        let total_live: usize = (0..graph.nodes.len())
            .map(|step| count_live_tensors_at_step(step, &tensor_memories))
            .sum();
        let avg_live = total_live as f64 / graph.nodes.len() as f64;
        let avg_memory = avg_live * (total_memory_bytes as f64 / tensor_memories.len() as f64);
        if peak_memory_bytes > 0 {
            avg_memory / peak_memory_bytes as f64
        } else {
            0.0
        }
    };

    // Generate optimal schedule
    let optimal_schedule = generate_memory_optimal_schedule(graph, &tensor_memories)?;

    Ok(MemoryAnalysis {
        tensors: tensor_memories,
        peak_memory_bytes,
        total_memory_bytes,
        avg_utilization,
        optimal_schedule,
    })
}

/// Analyze when each tensor is first and last used
fn analyze_tensor_lifetimes(graph: &EinsumGraph) -> Vec<(Option<usize>, Option<usize>)> {
    let mut lifetimes = vec![(None, None); graph.tensors.len()];

    for (node_idx, node) in graph.nodes.iter().enumerate() {
        // A node extends a tensor's lifetime whether it reads (input) or
        // writes (output) the tensor, so treat both uniformly
        for &tensor_idx in node.inputs.iter().chain(node.outputs.iter()) {
            if tensor_idx < lifetimes.len() {
                let (first, last) = &mut lifetimes[tensor_idx];
                *first = Some(first.map_or(node_idx, |f: usize| f.min(node_idx)));
                *last = Some(last.map_or(node_idx, |l: usize| l.max(node_idx)));
            }
        }
    }

    lifetimes
}

/// Estimate the size of a tensor in bytes (simplified)
fn estimate_tensor_size(
    _graph: &EinsumGraph,
    _tensor_idx: usize,
    element_size_bytes: usize,
) -> usize {
    // Simplified: assume 1000 elements per tensor
    // In practice, this would use shape information
    1000 * element_size_bytes
}

/// Compute peak memory usage across all execution steps
fn compute_peak_memory(graph: &EinsumGraph, tensors: &[TensorMemory]) -> usize {
    let mut peak = 0;

    for step in 0..graph.nodes.len() {
        let live_memory: usize = tensors
            .iter()
            .filter(|t| is_tensor_live_at_step(t, step))
            .map(|t| t.size_bytes)
            .sum();
        peak = peak.max(live_memory);
    }

    peak
}

/// Check if a tensor is live at a given execution step
fn is_tensor_live_at_step(tensor: &TensorMemory, step: usize) -> bool {
    match (tensor.first_use, tensor.last_use) {
        (Some(first), Some(last)) => step >= first && step <= last,
        _ => false,
    }
}

/// Count how many tensors are live at a given step
fn count_live_tensors_at_step(step: usize, tensors: &[TensorMemory]) -> usize {
    tensors
        .iter()
        .filter(|t| is_tensor_live_at_step(t, step))
        .count()
}

/// Generate an execution schedule that minimizes peak memory usage
///
/// This uses a greedy algorithm: among the operations whose dependencies
/// are already satisfied, it schedules the one expected to add the least
/// live memory (see `select_next_node_memory_aware`).
fn generate_memory_optimal_schedule(
    graph: &EinsumGraph,
    _tensors: &[TensorMemory],
) -> Result<Vec<usize>, IrError> {
    // Build dependency graph
    let dependencies = build_dependencies(graph);

    // Topological sort with memory-aware ordering
    let schedule = topological_sort_memory_aware(graph, &dependencies);

    Ok(schedule)
}

/// Build dependency map for the graph
fn build_dependencies(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
    let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();

    // Map each tensor to its producer node
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        for &output_idx in &node.outputs {
            tensor_producer.insert(output_idx, node_idx);
        }
    }

    // Build dependencies
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        let mut deps = Vec::new();
        for &input_idx in &node.inputs {
            if let Some(&producer) = tensor_producer.get(&input_idx) {
                if producer != node_idx {
                    deps.push(producer);
                }
            }
        }
        dependencies.insert(node_idx, deps);
    }

    dependencies
}

/// Topological sort with memory-aware heuristics
fn topological_sort_memory_aware(
    graph: &EinsumGraph,
    dependencies: &HashMap<usize, Vec<usize>>,
) -> Vec<usize> {
    let mut schedule = Vec::new();
    let mut scheduled = HashSet::new();
    let mut in_degree = vec![0; graph.nodes.len()];

    // Calculate in-degrees: a node's in-degree is the number of
    // dependencies it is still waiting on
    for (&node, deps) in dependencies {
        if node < in_degree.len() {
            in_degree[node] = deps.len();
        }
    }

    // Process nodes in order
    while schedule.len() < graph.nodes.len() {
        // Find all ready nodes (in-degree 0)
        let ready: Vec<usize> = (0..graph.nodes.len())
            .filter(|&i| !scheduled.contains(&i) && in_degree[i] == 0)
            .collect();

        if ready.is_empty() {
            break; // No more nodes can be scheduled (possible cycle)
        }

        // Select node with best memory characteristics
        let next = select_next_node_memory_aware(graph, &ready);
        schedule.push(next);
        scheduled.insert(next);

        // Update in-degrees: every unscheduled node that depends on `next`
        // now has one fewer unmet dependency
        for (&node, deps) in dependencies {
            if node < in_degree.len() && !scheduled.contains(&node) {
                let satisfied = deps.iter().filter(|&&d| d == next).count();
                in_degree[node] = in_degree[node].saturating_sub(satisfied);
            }
        }
    }

    schedule
}

/// Select next node to schedule based on memory characteristics
fn select_next_node_memory_aware(graph: &EinsumGraph, candidates: &[usize]) -> usize {
    // Simplified: prefer operations that allocate the least new memory
    // (i.e. produce the fewest outputs)
    candidates
        .iter()
        .min_by_key(|&&idx| {
            graph
                .nodes
                .get(idx)
                .map(|n| n.outputs.len())
                .unwrap_or(usize::MAX)
        })
        .copied()
        .unwrap_or(0)
}

/// Estimate memory savings from in-place operations
///
/// Identifies opportunities where operations could reuse input buffers
/// for outputs (in-place operations) to reduce memory footprint.
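///
/// # Example
///
/// A minimal sketch, assuming `EinsumGraph`, `EinsumNode`, and this function
/// are re-exported at the crate root (as the `analyze_memory` example
/// assumes for `EinsumGraph`):
///
/// ```rust
/// use tensorlogic_ir::{EinsumGraph, EinsumNode, analyze_inplace_opportunities};
///
/// let mut graph = EinsumGraph::new();
/// let a = graph.add_tensor("A");
/// let b = graph.add_tensor("B");
/// // `relu` is element-wise and unary, so it could reuse its input buffer
/// graph.add_node(EinsumNode::elem_unary("relu", a, b)).unwrap();
///
/// let candidates = analyze_inplace_opportunities(&graph).unwrap();
/// assert_eq!(candidates, vec![0]);
/// ```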
pub fn analyze_inplace_opportunities(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
    let mut inplace_candidates = Vec::new();

    for (node_idx, node) in graph.nodes.iter().enumerate() {
        if can_be_inplace(&node.op) && has_single_input_use(graph, node_idx) {
            inplace_candidates.push(node_idx);
        }
    }

    Ok(inplace_candidates)
}

/// Check if an operation can be performed in-place
fn can_be_inplace(op_type: &OpType) -> bool {
    // Element-wise unary operations can typically be done in-place
    matches!(op_type, OpType::ElemUnary { .. })
}

/// Check if a node's first input tensor is used only by this node
///
/// Only the first input is checked; it is the sole input of the unary
/// operations that `can_be_inplace` accepts.
fn has_single_input_use(graph: &EinsumGraph, node_idx: usize) -> bool {
    let node = &graph.nodes[node_idx];
    if node.inputs.is_empty() {
        return false;
    }

    let input_tensor = node.inputs[0];

    // Count how many nodes use this tensor as an input
    let use_count = graph
        .nodes
        .iter()
        .filter(|n| n.inputs.contains(&input_tensor))
        .count();

    use_count == 1
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    #[test]
    fn test_memory_analysis_default() {
        let analysis = MemoryAnalysis::default();
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.total_memory_bytes, 0);
    }

    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_memory(&graph, 8).unwrap();
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.tensors.len(), 0);
    }

    #[test]
    fn test_analyze_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let analysis = analyze_memory(&graph, 8).unwrap();
        assert!(analysis.peak_memory_bytes > 0);
        assert_eq!(analysis.tensors.len(), 2);
    }

    #[test]
    fn test_tensor_lifetime_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[a], (Some(0), Some(0)));
        assert_eq!(lifetimes[b], (Some(0), Some(0)));
    }

    #[test]
    fn test_tensor_lifetime_multiple_uses() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[b], (Some(0), Some(1)));
    }

    #[test]
    fn test_estimate_tensor_size() {
        let graph = EinsumGraph::new();
        let size = estimate_tensor_size(&graph, 0, 8);
        assert_eq!(size, 8000); // 1000 elements * 8 bytes
    }

    #[test]
    fn test_is_tensor_live_at_step() {
        let tensor = TensorMemory {
            tensor_idx: 0,
            size_bytes: 1000,
            first_use: Some(2),
            last_use: Some(5),
        };

        assert!(!is_tensor_live_at_step(&tensor, 0));
        assert!(!is_tensor_live_at_step(&tensor, 1));
        assert!(is_tensor_live_at_step(&tensor, 2));
        assert!(is_tensor_live_at_step(&tensor, 3));
        assert!(is_tensor_live_at_step(&tensor, 5));
        assert!(!is_tensor_live_at_step(&tensor, 6));
    }

    #[test]
    fn test_memory_waste_ratio_zero_peak() {
        let analysis = MemoryAnalysis {
            peak_memory_bytes: 0,
            total_memory_bytes: 1000,
            avg_utilization: 0.5,
            ..Default::default()
        };
        assert_eq!(analysis.memory_waste_ratio(), 0.0);
    }

    #[test]
    fn test_can_be_inplace() {
        assert!(can_be_inplace(&OpType::ElemUnary {
            op: "relu".to_string()
        }));
        assert!(!can_be_inplace(&OpType::Einsum {
            spec: "ij,jk->ik".to_string()
        }));
    }

    #[test]
    fn test_analyze_inplace_opportunities_empty() {
        let graph = EinsumGraph::new();
        let candidates = analyze_inplace_opportunities(&graph).unwrap();
        assert!(candidates.is_empty());
    }

    #[test]
    fn test_analyze_inplace_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let candidates = analyze_inplace_opportunities(&graph).unwrap();
        assert_eq!(candidates.len(), 1);
    }

    #[test]
    fn test_build_dependencies() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let deps = build_dependencies(&graph);
        assert_eq!(deps.get(&0).unwrap().len(), 0); // Node 0 has no dependencies
        assert_eq!(deps.get(&1).unwrap(), &vec![0]); // Node 1 depends on node 0
    }

    #[test]
    fn test_topological_sort_simple() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let deps = build_dependencies(&graph);
        let schedule = topological_sort_memory_aware(&graph, &deps);
        assert_eq!(schedule, vec![0]);
    }
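
    // A small regression sketch for the in-degree fix above: in a two-node
    // chain, the consumer must be scheduled after its producer.
    #[test]
    fn test_topological_sort_chain() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let deps = build_dependencies(&graph);
        let schedule = topological_sort_memory_aware(&graph, &deps);
        assert_eq!(schedule, vec![0, 1]);
    }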
}