tensorlogic_infer/memory.rs

//! Memory estimation utilities for execution planning.

use tensorlogic_ir::{EinsumGraph, OpType};

use crate::capabilities::DType;

/// Memory usage estimate for a tensor
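///
/// A minimal usage sketch. It assumes `DType::F64.byte_size()` returns 8;
/// the example is marked `ignore` because the crate's public paths are not
/// shown in this file:
///
/// ```ignore
/// let mem = TensorMemory::new(0, vec![1024, 1024], DType::F64);
/// assert_eq!(mem.element_count, 1024 * 1024);
/// assert_eq!(mem.bytes, 8 * 1024 * 1024);
/// assert_eq!(mem.megabytes(), 8.0);
/// ```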
#[derive(Debug, Clone)]
pub struct TensorMemory {
    pub tensor_idx: usize,
    pub shape: Vec<usize>,
    pub element_count: usize,
    pub bytes: usize,
}

impl TensorMemory {
    pub fn new(tensor_idx: usize, shape: Vec<usize>, dtype: DType) -> Self {
        let element_count: usize = shape.iter().product();
        let bytes = element_count * dtype.byte_size();

        TensorMemory {
            tensor_idx,
            shape,
            element_count,
            bytes,
        }
    }

    /// Size in megabytes (binary: 1 MB = 1024 * 1024 bytes)
    pub fn megabytes(&self) -> f64 {
        self.bytes as f64 / (1024.0 * 1024.0)
    }
}

/// Complete memory profile for graph execution
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    pub input_memory: Vec<TensorMemory>,
    pub intermediate_memory: Vec<TensorMemory>,
    pub output_memory: Vec<TensorMemory>,
    pub total_bytes: usize,
    pub peak_bytes: usize,
}

impl MemoryEstimate {
    pub fn new() -> Self {
        MemoryEstimate {
            input_memory: Vec::new(),
            intermediate_memory: Vec::new(),
            output_memory: Vec::new(),
            total_bytes: 0,
            peak_bytes: 0,
        }
    }

    pub fn total_megabytes(&self) -> f64 {
        self.total_bytes as f64 / (1024.0 * 1024.0)
    }

    pub fn peak_megabytes(&self) -> f64 {
        self.peak_bytes as f64 / (1024.0 * 1024.0)
    }

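    /// Render the estimate as a human-readable multi-line string.
    ///
    /// Illustrative output (the counts and sizes depend on the graph and dtype):
    ///
    /// ```text
    /// Memory Estimate:
    /// - Inputs: 3 tensors, 0.00 MB
    /// - Intermediates: 1 tensors, 0.00 MB
    /// - Outputs: 1 tensors, 0.00 MB
    /// - Total: 0.00 MB
    /// - Peak: 0.00 MB
    /// ```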
    pub fn summary(&self) -> String {
        format!(
            "Memory Estimate:\n\
             - Inputs: {} tensors, {:.2} MB\n\
             - Intermediates: {} tensors, {:.2} MB\n\
             - Outputs: {} tensors, {:.2} MB\n\
             - Total: {:.2} MB\n\
             - Peak: {:.2} MB",
            self.input_memory.len(),
            self.input_memory.iter().map(|t| t.megabytes()).sum::<f64>(),
            self.intermediate_memory.len(),
            self.intermediate_memory
                .iter()
                .map(|t| t.megabytes())
                .sum::<f64>(),
            self.output_memory.len(),
            self.output_memory
                .iter()
                .map(|t| t.megabytes())
                .sum::<f64>(),
            self.total_megabytes(),
            self.peak_megabytes()
        )
    }
}

impl Default for MemoryEstimate {
    fn default() -> Self {
        Self::new()
    }
}

/// Memory estimator for execution graphs
pub struct MemoryEstimator {
    dtype: DType,
}

impl MemoryEstimator {
    pub fn new(dtype: DType) -> Self {
        MemoryEstimator { dtype }
    }

    /// Estimate memory usage for a graph.
    ///
    /// Note: uses a default shape of `[10]` for every tensor, since the
    /// graph stores only tensor names, not concrete shapes.
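    ///
    /// A hypothetical call (building the `EinsumGraph` itself is up to
    /// `tensorlogic_ir`, so the example is not compiled):
    ///
    /// ```ignore
    /// let estimator = MemoryEstimator::new(DType::F64);
    /// let estimate = estimator.estimate(&graph);
    /// println!("{}", estimate.summary());
    /// ```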
    pub fn estimate(&self, graph: &EinsumGraph) -> MemoryEstimate {
        let mut estimate = MemoryEstimate::new();
        let default_shape = vec![10]; // Default shape for estimation

        // Estimate input tensors
        for idx in 0..graph.tensors.len() {
            let mem = TensorMemory::new(idx, default_shape.clone(), self.dtype);
            estimate.total_bytes += mem.bytes;
            estimate.input_memory.push(mem);
        }

        // Estimate intermediate/output tensors from nodes
        let num_inputs = graph.tensors.len();
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let tensor_idx = num_inputs + node_idx;
            let shape = self.estimate_output_shape(node, graph);

            let mem = TensorMemory::new(tensor_idx, shape, self.dtype);
            estimate.total_bytes += mem.bytes;

            // Last node is output, others are intermediates
            if node_idx == graph.nodes.len() - 1 {
                estimate.output_memory.push(mem);
            } else {
                estimate.intermediate_memory.push(mem);
            }
        }

        // Peak memory is when all tensors are alive
        // (simplified - doesn't account for tensor lifetimes)
        estimate.peak_bytes = estimate.total_bytes;

        estimate
    }

    /// Estimate memory with tensor lifetime analysis for peak usage
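    ///
    /// Because a tensor is freed after its last use, the lifetime-aware peak
    /// never exceeds the naive all-alive total (sketch, not compiled):
    ///
    /// ```ignore
    /// let estimate = estimator.estimate_with_lifetime(&graph);
    /// assert!(estimate.peak_bytes <= estimate.total_bytes);
    /// ```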
    pub fn estimate_with_lifetime(&self, graph: &EinsumGraph) -> MemoryEstimate {
        let mut estimate = self.estimate(graph);

        // Track which tensors are alive at each point
        let num_tensors = graph.tensors.len() + graph.nodes.len();
        let mut alive = vec![false; num_tensors];

        // Input tensors are initially alive
        for item in alive.iter_mut().take(graph.tensors.len()) {
            *item = true;
        }

        let mut peak_bytes = 0;

        for (node_idx, node) in graph.nodes.iter().enumerate() {
            // Mark this node's output tensor as alive
            let output_idx = graph.tensors.len() + node_idx;
            alive[output_idx] = true;

            // Sum the bytes of every tensor that is currently alive
            let current_bytes: usize = alive
                .iter()
                .enumerate()
                .filter(|(_, &is_alive)| is_alive)
                .map(|(idx, _)| {
                    if idx < graph.tensors.len() {
                        // Input tensor
                        &estimate.input_memory[idx]
                    } else {
                        // Intermediate/output tensor
                        let node_offset = idx - graph.tensors.len();
                        if node_offset < estimate.intermediate_memory.len() {
                            &estimate.intermediate_memory[node_offset]
                        } else {
                            &estimate.output_memory[0]
                        }
                    }
                })
                .map(|mem| mem.bytes)
                .sum();

            peak_bytes = peak_bytes.max(current_bytes);

            // Free each input tensor once no later node reads it
            for &input_idx in &node.inputs {
                if self.is_last_use(input_idx, node_idx, graph) {
                    alive[input_idx] = false;
                }
            }
        }

        estimate.peak_bytes = peak_bytes;
        estimate
    }

    fn estimate_output_shape(
        &self,
        node: &tensorlogic_ir::EinsumNode,
        _graph: &EinsumGraph,
    ) -> Vec<usize> {
        // Since the graph only stores tensor names, we use default shapes
        match &node.op {
            OpType::Einsum { spec } => {
                // Simplified: one output axis per character after "->",
                // each with the placeholder extent of 10
                if let Some(arrow_pos) = spec.find("->") {
                    let output_axes = spec[arrow_pos + 2..].trim();
                    vec![10; output_axes.len()]
                } else {
                    // No explicit output spec - fall back to the default shape
                    vec![10]
                }
            }
            OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
                // Shape preserved for element-wise ops - use the default
                vec![10]
            }
            OpType::Reduce { op: _, axes } => {
                // Remove reduced axes from an assumed 2D default shape
                let mut shape = vec![10, 10];
                for &axis in axes.iter().rev() {
                    if axis < shape.len() {
                        shape.remove(axis);
                    }
                }
                // A full reduction yields a scalar; model it as shape [1]
                if shape.is_empty() {
                    vec![1]
                } else {
                    shape
                }
            }
        }
    }

    fn is_last_use(&self, tensor_idx: usize, current_node: usize, graph: &EinsumGraph) -> bool {
        // The tensor is dead after `current_node` if no later node reads it
        !graph
            .nodes
            .iter()
            .skip(current_node + 1)
            .any(|node| node.inputs.contains(&tensor_idx))
    }

    /// Estimate memory for a batch of inputs by scaling the single-sample
    /// estimate by `batch_size`.
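    ///
    /// For example, doubling the batch size doubles every byte count
    /// (sketch, not compiled):
    ///
    /// ```ignore
    /// let single = estimator.estimate(&graph);
    /// let batched = estimator.estimate_batch(&graph, 2);
    /// assert_eq!(batched.total_bytes, 2 * single.total_bytes);
    /// ```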
    pub fn estimate_batch(&self, graph: &EinsumGraph, batch_size: usize) -> MemoryEstimate {
        let single_estimate = self.estimate(graph);

        let mut batch_estimate = MemoryEstimate::new();
        batch_estimate.total_bytes = single_estimate.total_bytes * batch_size;
        batch_estimate.peak_bytes = single_estimate.peak_bytes * batch_size;

        // Scale every tensor by the batch size, prepending the batch
        // dimension so shape, element count, and bytes stay consistent
        let batch = |mem: &TensorMemory| {
            let mut batched = mem.clone();
            batched.shape.insert(0, batch_size);
            batched.element_count *= batch_size;
            batched.bytes *= batch_size;
            batched
        };

        batch_estimate.input_memory = single_estimate.input_memory.iter().map(batch).collect();
        batch_estimate.intermediate_memory = single_estimate
            .intermediate_memory
            .iter()
            .map(batch)
            .collect();
        batch_estimate.output_memory = single_estimate.output_memory.iter().map(batch).collect();

        batch_estimate
    }
}

impl Default for MemoryEstimator {
    fn default() -> Self {
        Self::new(DType::F64)
    }
}
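
// A small smoke-test module. The byte-size assertions assume
// `DType::F64.byte_size()` returns 8; adjust if the crate defines it
// differently.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_memory_accounts_bytes() {
        let mem = TensorMemory::new(0, vec![1024, 1024], DType::F64);
        assert_eq!(mem.element_count, 1024 * 1024);
        assert_eq!(mem.bytes, 8 * 1024 * 1024);
        assert!((mem.megabytes() - 8.0).abs() < f64::EPSILON);
    }

    #[test]
    fn empty_estimate_is_zero() {
        let estimate = MemoryEstimate::new();
        assert_eq!(estimate.total_bytes, 0);
        assert_eq!(estimate.peak_bytes, 0);
        assert_eq!(estimate.total_megabytes(), 0.0);
    }
}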