tensorlogic_infer/memory.rs

use tensorlogic_ir::{EinsumGraph, OpType};

use crate::capabilities::DType;

/// Memory footprint of a single tensor: its shape, element count, and
/// size in bytes for a given dtype.
#[derive(Debug, Clone)]
pub struct TensorMemory {
    pub tensor_idx: usize,
    pub shape: Vec<usize>,
    pub element_count: usize,
    pub bytes: usize,
}

impl TensorMemory {
    pub fn new(tensor_idx: usize, shape: Vec<usize>, dtype: DType) -> Self {
        let element_count: usize = shape.iter().product();
        let bytes = element_count * dtype.byte_size();

        TensorMemory {
            tensor_idx,
            shape,
            element_count,
            bytes,
        }
    }

    pub fn megabytes(&self) -> f64 {
        self.bytes as f64 / (1024.0 * 1024.0)
    }
}
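
// A minimal behavioral sketch, assuming `DType::F64` reports an 8-byte
// element from `byte_size()` (F64 is the dtype this module defaults to).
#[cfg(test)]
mod tensor_memory_tests {
    use super::*;

    #[test]
    fn bytes_follow_shape_and_dtype() {
        // A 3 x 4 tensor of f64: 12 elements, 96 bytes.
        let mem = TensorMemory::new(0, vec![3, 4], DType::F64);
        assert_eq!(mem.element_count, 12);
        assert_eq!(mem.bytes, 96);
        assert!(mem.megabytes() < 1.0);
    }
}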

/// Aggregated memory estimate for a whole graph, split by tensor role.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    pub input_memory: Vec<TensorMemory>,
    pub intermediate_memory: Vec<TensorMemory>,
    pub output_memory: Vec<TensorMemory>,
    pub total_bytes: usize,
    pub peak_bytes: usize,
}

impl MemoryEstimate {
    pub fn new() -> Self {
        MemoryEstimate {
            input_memory: Vec::new(),
            intermediate_memory: Vec::new(),
            output_memory: Vec::new(),
            total_bytes: 0,
            peak_bytes: 0,
        }
    }

    pub fn total_megabytes(&self) -> f64 {
        self.total_bytes as f64 / (1024.0 * 1024.0)
    }

    pub fn peak_megabytes(&self) -> f64 {
        self.peak_bytes as f64 / (1024.0 * 1024.0)
    }

    pub fn summary(&self) -> String {
        format!(
            "Memory Estimate:\n\
             - Inputs: {} tensors, {:.2} MB\n\
             - Intermediates: {} tensors, {:.2} MB\n\
             - Outputs: {} tensors, {:.2} MB\n\
             - Total: {:.2} MB\n\
             - Peak: {:.2} MB",
            self.input_memory.len(),
            self.input_memory.iter().map(|t| t.megabytes()).sum::<f64>(),
            self.intermediate_memory.len(),
            self.intermediate_memory
                .iter()
                .map(|t| t.megabytes())
                .sum::<f64>(),
            self.output_memory.len(),
            self.output_memory
                .iter()
                .map(|t| t.megabytes())
                .sum::<f64>(),
            self.total_megabytes(),
            self.peak_megabytes()
        )
    }
}

impl Default for MemoryEstimate {
    fn default() -> Self {
        Self::new()
    }
}
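
// A small sketch of how the aggregate fields relate; it assumes the caller
// (normally `MemoryEstimator`) keeps `total_bytes` in sync with the lists,
// and that `DType::F64` elements are 8 bytes.
#[cfg(test)]
mod memory_estimate_tests {
    use super::*;

    #[test]
    fn totals_convert_to_megabytes() {
        let mut estimate = MemoryEstimate::new();
        let input = TensorMemory::new(0, vec![1024, 1024], DType::F64);
        estimate.total_bytes += input.bytes;
        estimate.input_memory.push(input);

        // 1024 * 1024 f64 elements = exactly 8 MiB.
        assert_eq!(estimate.total_megabytes(), 8.0);
        assert!(estimate.summary().contains("Inputs: 1 tensors"));
    }
}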

/// Estimates per-tensor and peak memory for an `EinsumGraph`.
pub struct MemoryEstimator {
    dtype: DType,
}

impl MemoryEstimator {
    pub fn new(dtype: DType) -> Self {
        MemoryEstimator { dtype }
    }

    /// Estimate memory for every tensor in the graph.
    ///
    /// The IR carries no concrete dimension sizes, so every axis is given
    /// a placeholder extent of 10 elements.
    pub fn estimate(&self, graph: &EinsumGraph) -> MemoryEstimate {
        let mut estimate = MemoryEstimate::new();

        // Every tensor declared in the graph is treated as an input.
        let default_shape = vec![10];
        for idx in 0..graph.tensors.len() {
            let mem = TensorMemory::new(idx, default_shape.clone(), self.dtype);
            estimate.total_bytes += mem.bytes;
            estimate.input_memory.push(mem);
        }

        // Each node produces one new tensor, indexed after the inputs.
        let num_inputs = graph.tensors.len();
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let tensor_idx = num_inputs + node_idx;
            let shape = self.estimate_output_shape(node, graph);

            let mem = TensorMemory::new(tensor_idx, shape, self.dtype);
            estimate.total_bytes += mem.bytes;

            // The last node produces the graph output; everything else is
            // an intermediate.
            if node_idx == graph.nodes.len() - 1 {
                estimate.output_memory.push(mem);
            } else {
                estimate.intermediate_memory.push(mem);
            }
        }

        // Conservative peak: assume all tensors are alive at once.
        // `estimate_with_lifetime` refines this with liveness analysis.
        estimate.peak_bytes = estimate.total_bytes;

        estimate
    }

    /// Like `estimate`, but computes `peak_bytes` with a simple liveness
    /// analysis: a tensor stops counting after its last use.
    pub fn estimate_with_lifetime(&self, graph: &EinsumGraph) -> MemoryEstimate {
        let mut estimate = self.estimate(graph);

        // One liveness flag per tensor: graph inputs first, then one
        // output per node.
        let num_tensors = graph.tensors.len() + graph.nodes.len();
        let mut alive = vec![false; num_tensors];

        // Inputs are alive from the start.
        for item in alive.iter_mut().take(graph.tensors.len()) {
            *item = true;
        }

        let mut peak_bytes = 0;

        for (node_idx, node) in graph.nodes.iter().enumerate() {
            // The node's output becomes alive once it is produced.
            let output_idx = graph.tensors.len() + node_idx;
            alive[output_idx] = true;

            // Sum the bytes of every currently-alive tensor, mapping each
            // flat index back to the matching entry in the estimate.
            let current_bytes: usize = alive
                .iter()
                .enumerate()
                .filter(|(_, &is_alive)| is_alive)
                .map(|(idx, _)| {
                    if idx < graph.tensors.len() {
                        &estimate.input_memory[idx]
                    } else {
                        let node_offset = idx - graph.tensors.len();
                        if node_offset < estimate.intermediate_memory.len() {
                            &estimate.intermediate_memory[node_offset]
                        } else {
                            // Only the final node's output lands here.
                            &estimate.output_memory[0]
                        }
                    }
                })
                .map(|mem| mem.bytes)
                .sum();

            peak_bytes = peak_bytes.max(current_bytes);

            // Free inputs whose last consumer is this node so they no
            // longer count toward later peaks.
            for &input_idx in &node.inputs {
                if self.is_last_use(input_idx, node_idx, graph) {
                    alive[input_idx] = false;
                }
            }
        }

        estimate.peak_bytes = peak_bytes;
        estimate
    }
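
    // Worked example of the liveness model, for a two-node chain
    // A -> n0 -> n1 where n0 reads input A and n1 reads n0's output t0:
    //   at n0: alive = {A, t0},  candidate peak = |A| + |t0|
    //          (A's last use is n0, so A is freed afterwards)
    //   at n1: alive = {t0, t1}, candidate peak = |t0| + |t1|
    // The peak is the larger of the two sums, rather than |A| + |t0| + |t1|
    // as the plain `estimate` would report.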

    /// Guess an output shape for a node. Axis extents are unknown in the
    /// IR, so every axis defaults to a placeholder size of 10.
    fn estimate_output_shape(
        &self,
        node: &tensorlogic_ir::EinsumNode,
        _graph: &EinsumGraph,
    ) -> Vec<usize> {
        match &node.op {
            OpType::Einsum { spec } => {
                // One placeholder axis per character after "->"; this
                // assumes single-character axis labels in the spec.
                if let Some(arrow_pos) = spec.find("->") {
                    let output_axes = &spec[arrow_pos + 2..];
                    vec![10; output_axes.len()]
                } else {
                    // No explicit output: fall back to a single axis.
                    vec![10]
                }
            }
            // Elementwise ops preserve their input shape; use the
            // single-axis placeholder.
            OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
                vec![10]
            }
            OpType::Reduce { op: _, axes } => {
                // Start from a rank-2 placeholder and remove the reduced
                // axes in reverse so earlier removals don't shift the
                // indices still to be removed (assumes `axes` is sorted
                // ascending).
                let mut shape = vec![10, 10];
                for &axis in axes.iter().rev() {
                    if axis < shape.len() {
                        shape.remove(axis);
                    }
                }
                // A full reduction yields a scalar, modeled as [1].
                if shape.is_empty() {
                    vec![1]
                } else {
                    shape
                }
            }
        }
    }

    /// Returns true if no node after `current_node` reads `tensor_idx`,
    /// i.e. the tensor can be freed once the current node has executed.
    fn is_last_use(&self, tensor_idx: usize, current_node: usize, graph: &EinsumGraph) -> bool {
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            if node_idx > current_node && node.inputs.contains(&tensor_idx) {
                return false;
            }
        }
        true
    }

    /// Scale a per-sample estimate to a batch. This assumes every tensor
    /// grows linearly with the batch, i.e. each gains a leading batch axis.
    pub fn estimate_batch(&self, graph: &EinsumGraph, batch_size: usize) -> MemoryEstimate {
        let single_estimate = self.estimate(graph);

        let mut batch_estimate = MemoryEstimate::new();
        batch_estimate.total_bytes = single_estimate.total_bytes * batch_size;
        batch_estimate.peak_bytes = single_estimate.peak_bytes * batch_size;

        // Keep shape and element count consistent with the scaled byte
        // count, not just `bytes`.
        let batch = |mem: &TensorMemory| -> TensorMemory {
            let mut batched = mem.clone();
            batched.shape.insert(0, batch_size);
            batched.element_count *= batch_size;
            batched.bytes *= batch_size;
            batched
        };

        for input in &single_estimate.input_memory {
            batch_estimate.input_memory.push(batch(input));
        }

        for intermediate in &single_estimate.intermediate_memory {
            batch_estimate.intermediate_memory.push(batch(intermediate));
        }

        for output in &single_estimate.output_memory {
            batch_estimate.output_memory.push(batch(output));
        }

        batch_estimate
    }
}

impl Default for MemoryEstimator {
    fn default() -> Self {
        Self::new(DType::F64)
    }
}
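
// Usage sketch. Building an `EinsumGraph` is omitted because the graph
// construction API lives in `tensorlogic_ir`; the flow below assumes a
// graph named `graph` is already in scope:
//
//     let estimator = MemoryEstimator::new(DType::F64);
//     let quick = estimator.estimate(&graph);               // peak = total
//     let tight = estimator.estimate_with_lifetime(&graph); // liveness-aware
//     let batched = estimator.estimate_batch(&graph, 32);
//     println!("{}", tight.summary());
//     assert!(tight.peak_bytes <= quick.peak_bytes);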