1use std::fmt;
2use std::fmt::Debug;
3use tract_core::internal::num_integer::Integer;
4use tract_core::internal::*;
5
6use crate::fact::DeviceTypedFactExt;
7use crate::sync::{DeviceSync, DeviceSyncKind};
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub struct NodeMemReq {
12 pub node: usize,
13 pub lifetime: Lifetime,
14 pub mem_size: TDim,
15}
16
/// Half-open range of evaluation steps `[start, end)` during which a buffer
/// is alive.
///
/// A lifetime whose `end` is not greater than its `start` is considered
/// empty: it is never alive and has a length of zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step at which the buffer is alive (inclusive).
    pub start: usize,
    /// First step at which the buffer is no longer alive (exclusive).
    pub end: usize,
}

impl Lifetime {
    /// Returns `true` when the two half-open ranges cannot be alive at the
    /// same step, i.e. one of them ends before (or exactly when) the other
    /// one starts.
    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    /// Returns `true` when `step` belongs to `[start, end)`.
    pub fn is_alive_at_step(&self, step: usize) -> bool {
        (self.start..self.end).contains(&step)
    }

    /// Returns `true` when the lifetime covers no step at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps covered by the lifetime.
    ///
    /// The fields are public, so nothing prevents `end < start`. The
    /// subtraction saturates so that such an inverted range is reported as
    /// empty instead of panicking (debug) or wrapping around (release).
    pub fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}
40
41fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
42 if node.outputs.is_empty() {
43 return None;
44 };
45
46 Some(
47 node.outputs
48 .iter()
49 .flat_map(|o| {
50 o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
51 })
52 .collect(),
53 )
54}
55
56pub fn eval_device_mem_req_for_nodes(
57 model: &TypedModel,
58 order: &[usize],
59) -> TractResult<TVec<NodeMemReq>> {
60 let outputs = model.output_outlets()?.to_vec();
61 let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
62 let Ok(facts) = model.node_output_facts(node.id) else { return false };
63
64 let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
65 nodes.iter().any(|it| {
66 it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
67 })
68 });
69
70 !cpu_sync_in_next_nodes
71 && node.op.name() != "MetalDynKVCache"
72 && node.op.name() != "CudaDynKVCache"
73 && facts
74 .iter()
75 .any(|it| it.to_device_fact().map(|it| it.is_from_device()).unwrap_or(false))
76 });
77 let mut scoped_nodes = tvec![];
78
79 for (step, n) in order.iter().enumerate() {
80 let lifetime_start = step;
81
82 let lifetime_end = flush_lists
83 .iter()
84 .enumerate()
85 .find(|(_step, flush_list)| flush_list.contains(n))
86 .map(|it| usize::min(it.0 + 1, order.len()));
87 let Some(lifetime_end) = lifetime_end else {
89 continue;
90 };
91
92 let out_device_tmp_facts = model
93 .node_output_facts(*n)?
94 .into_iter()
95 .flat_map(|it| it.to_device_fact().ok())
96 .filter(|it| it.is_from_device())
97 .collect::<TVec<_>>();
98
99 if out_device_tmp_facts.is_empty() {
100 continue;
101 }
102
103 scoped_nodes.push(NodeMemReq {
104 node: *n,
105 lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
106 mem_size: out_device_tmp_facts.iter().map(|it| it.mem_size()).sum::<TDim>(),
107 })
108 }
109
110 Ok(scoped_nodes)
111}
112
113#[derive(Debug, Clone, PartialEq, Eq, Hash)]
116pub struct Partition {
117 pub nodes: Vec<NodeMemReq>,
118}
119
120impl Partition {
121 pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
122 let mut max_size = self
123 .nodes
124 .iter()
125 .map(|it| it.mem_size.eval_to_i64(symbols))
126 .collect::<TractResult<Vec<_>>>()?
127 .into_iter()
128 .max()
129 .unwrap_or(0);
130 max_size = Integer::next_multiple_of(&max_size, &(vector_size() as i64));
131 Ok(max_size)
132 }
133
134 pub fn size(&self) -> TDim {
135 TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
136 }
137
138 pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
139 self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
140 }
141
142 pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
143 self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
144 }
145}
146
/// Memory schema with every symbolic size evaluated to a concrete value.
///
/// Built by `DeviceMemSchema::resolve`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeviceResolvedMemSchema {
    /// Offset assigned to each node, indexed by node id; `None` for nodes
    /// that belong to no partition.
    pub offsets_by_node: Vec<Option<usize>>,
    /// Sum of the evaluated partition sizes.
    pub memory_size: usize,
}
154
155#[derive(Debug, Clone, PartialEq, Eq, Hash)]
158pub struct DeviceMemSchema {
159 pub model_num_nodes: usize,
161 pub by_partition: Vec<Partition>,
162 pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
164}
165
166impl DeviceMemSchema {
167 pub fn size_by_partition(&self) -> Vec<TDim> {
169 self.by_partition.iter().map(|it| it.size()).collect()
170 }
171
172 pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
174 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
175 }
176
177 pub fn memory_size(&self) -> TDim {
179 self.by_partition.iter().map(|it| it.size()).sum()
180 }
181
182 pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
184 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
185 }
186
187 pub fn compute_offset_by_node(
190 &self,
191 symbols: &SymbolValues,
192 ) -> TractResult<Vec<Option<usize>>> {
193 let mut cursor = 0;
194 let mut offset_by_node = vec![None; self.model_num_nodes];
195
196 for partition in self.by_partition.iter() {
197 for node_mem in partition.nodes.iter() {
198 offset_by_node[node_mem.node] = Some(cursor);
199 }
200 cursor += partition.eval_size_to_i64(symbols)? as usize;
201 }
202
203 Ok(offset_by_node)
204 }
205
206 pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
210 Ok(self
211 .by_steps
212 .iter()
213 .map(|active_nodes| {
214 active_nodes
215 .iter()
216 .flatten()
217 .map(|it| it.mem_size.clone())
218 .sum::<TDim>()
219 .eval_to_i64(symbols)
220 })
221 .collect::<TractResult<Vec<_>>>()?
222 .into_iter()
223 .max()
224 .unwrap_or(0))
225 }
226
227 pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
231 let memory_size = self.eval_memory_size(symbols)? as f32;
232 let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
233 Ok(peak_memory_size / memory_size)
234 }
235}
236
237impl fmt::Display for DeviceMemSchema {
238 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
239 for (step, mem_step) in self.by_steps.iter().enumerate() {
240 writeln!(
241 fmt,
242 "step: {:5} => |{}|",
243 step,
244 mem_step
245 .iter()
246 .map(|n| -> String {
247 n.as_ref()
248 .map(|it| format!("{:^7}", it.node))
249 .unwrap_or(format!("{:^7}", "*"))
250 })
251 .collect::<Vec<String>>()
252 .join("|")
253 )?;
254 }
255 writeln!(fmt, "memory_size: {}", self.memory_size())?;
256 Ok(())
257 }
258}
259
260impl DeviceMemSchema {
261 pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
263 Ok(DeviceResolvedMemSchema {
264 offsets_by_node: self.compute_offset_by_node(symbols)?,
265 memory_size: self.eval_memory_size(symbols)?.try_into()?,
266 })
267 }
268
269 pub fn build(
273 model: &TypedModel,
274 order: &[usize],
275 hint: &SymbolValues,
276 ) -> TractResult<DeviceMemSchema> {
277 let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;
278
279 let hinted_mem_size = nodes_mem_req
280 .iter()
281 .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?)))
282 .collect::<TractResult<HashMap<usize, i64>>>()?;
283
284 nodes_mem_req.sort_by(|lhs, rhs| {
285 let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node);
286 let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node);
287 lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
288 });
289
290 let mut partitions: Vec<Partition> = vec![];
291 for node_mem in nodes_mem_req {
292 let mut available = partitions
294 .iter_mut()
295 .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
296 .collect::<Vec<_>>();
297
298 available.sort_by_cached_key(|n| {
299 -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>()
300 });
301
302 match available.first_mut() {
303 Some(available) => {
304 available.nodes.push(node_mem);
305 }
306 None => partitions.push(Partition { nodes: vec![node_mem] }),
307 }
308 }
309
310 let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
311 .map(|step| {
312 let mem_step: Vec<_> =
313 partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
314 ensure!(mem_step.len() <= partitions.len());
315 Ok(mem_step)
316 })
317 .collect::<TractResult<Vec<_>>>()?;
318
319 Ok(DeviceMemSchema {
320 model_num_nodes: model.nodes().len(),
321 by_partition: partitions,
322 by_steps,
323 })
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the requirement of `node`, alive over `[start, end)`, for a
    /// constant `size`.
    fn mem_req(node: usize, start: usize, end: usize, size: i32) -> NodeMemReq {
        NodeMemReq { node, lifetime: Lifetime { start, end }, mem_size: size.into() }
    }

    #[test]
    fn test_lifetime_is_disjoint() {
        let head = Lifetime { start: 0, end: 5 };
        let tail = Lifetime { start: 5, end: 10 };
        let straddling = Lifetime { start: 3, end: 7 };

        // Adjacent ranges do not overlap, whatever the argument order.
        assert!(head.is_disjoint(&tail));
        assert!(tail.is_disjoint(&head));
        // A range straddling the boundary overlaps both sides.
        assert!(!head.is_disjoint(&straddling));
        assert!(!straddling.is_disjoint(&tail));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lifetime = Lifetime { start: 5, end: 10 };

        // The start is inclusive, the end is exclusive.
        assert!(!lifetime.is_alive_at_step(4));
        assert!(lifetime.is_alive_at_step(5));
        assert!(lifetime.is_alive_at_step(7));
        assert!(lifetime.is_alive_at_step(9));
        assert!(!lifetime.is_alive_at_step(10));
    }

    #[test]
    fn test_empty_lifetime() {
        let lifetime = Lifetime { start: 5, end: 5 };

        assert!(lifetime.is_empty());
        assert_eq!(lifetime.len(), 0);
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = mem_req(1, 0, 5, 1000);

        assert_eq!(req.node, 1);
        assert_eq!(req.lifetime.start, 0);
        assert_eq!(req.lifetime.end, 5);
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let partition = Partition { nodes: vec![mem_req(1, 0, 5, 1000)] };

        assert!(partition.has_no_conflict_with_lifetime(&Lifetime { start: 5, end: 10 }));
        assert!(!partition.has_no_conflict_with_lifetime(&Lifetime { start: 3, end: 7 }));
    }

    #[test]
    fn test_partition_find_node() {
        let first = mem_req(1, 0, 5, 1000);
        let second = mem_req(2, 5, 10, 2000);
        let partition = Partition { nodes: vec![first.clone(), second.clone()] };

        assert_eq!(partition.find_node_alive_at_step(3), Some(&first));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&second));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}