1use std::fmt;
2use std::fmt::Debug;
3use tract_core::internal::*;
4
5use crate::fact::DeviceTypedFactExt;
6use crate::sync::{DeviceSync, DeviceSyncKind};
7
8#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10pub struct NodeMemReq {
11 pub node: usize,
12 pub lifetime: Lifetime,
13 pub mem_size: TDim,
14}
15
/// Half-open range `[start, end)` of execution steps.
///
/// `start` is inclusive and `end` is exclusive: a lifetime is alive at `start`
/// but no longer alive at `end`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
    /// First step at which the value is alive (inclusive).
    pub start: usize,
    /// First step at which the value is no longer alive (exclusive).
    pub end: usize,
}

impl Lifetime {
    /// Returns `true` when `self` and `other` share no step.
    ///
    /// Because ranges are half-open, two lifetimes that merely touch
    /// (`self.end == other.start`) are disjoint.
    pub fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }

    /// Returns `true` when `step` lies in `[start, end)`.
    pub fn is_alive_at_step(&self, step: usize) -> bool {
        self.start <= step && step < self.end
    }

    /// Returns `true` when the lifetime covers no step at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps covered by the lifetime.
    ///
    /// An inverted range (`end < start`) covers no step and reports `0`
    /// instead of overflowing the subtraction.
    pub fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}
39
40fn next_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<TVec<&'a TypedNode>> {
41 if node.outputs.is_empty() {
42 return None;
43 };
44
45 Some(
46 node.outputs
47 .iter()
48 .flat_map(|o| {
49 o.successors.iter().map(|succ| &model.nodes()[succ.node]).collect::<Vec<_>>()
50 })
51 .collect(),
52 )
53}
54
55pub fn eval_device_mem_req_for_nodes(
56 model: &TypedModel,
57 order: &[usize],
58) -> TractResult<TVec<NodeMemReq>> {
59 let outputs = model.output_outlets()?.to_vec();
60 let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
61 let Ok(facts) = model.node_output_facts(node.id) else { return false };
62
63 let cpu_sync_in_next_nodes = next_nodes(model, node).is_some_and(|nodes| {
64 nodes.iter().any(|it| {
65 it.op_as::<DeviceSync>().is_some_and(|op| op.kind == DeviceSyncKind::ToHost)
66 })
67 });
68
69 !cpu_sync_in_next_nodes
70 && facts
71 .iter()
72 .any(|it| it.to_device_fact().map(|it| it.is_from_device()).unwrap_or(false))
73 });
74 let mut scoped_nodes = tvec![];
75
76 for (step, n) in order.iter().enumerate() {
77 let lifetime_start = step;
78 let lifetime_end = flush_lists
79 .iter()
80 .enumerate()
81 .find(|(_step, flush_list)| flush_list.contains(n))
82 .map(|it| usize::min(it.0 + 1, order.len()));
83
84 let Some(lifetime_end) = lifetime_end else {
86 continue;
87 };
88
89 let out_device_tmp_facts = model
90 .node_output_facts(*n)?
91 .into_iter()
92 .flat_map(|it| it.to_device_fact().ok())
93 .filter(|it| it.is_from_device())
94 .collect::<TVec<_>>();
95
96 if out_device_tmp_facts.is_empty() {
97 continue;
98 }
99
100 scoped_nodes.push(NodeMemReq {
101 node: *n,
102 lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
103 mem_size: out_device_tmp_facts.iter().map(|it| it.mem_size()).sum::<TDim>(),
104 })
105 }
106
107 Ok(scoped_nodes)
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Hash)]
113pub struct Partition {
114 pub nodes: Vec<NodeMemReq>,
115}
116
117impl Partition {
118 pub fn eval_size_to_i64(&self, symbols: &SymbolValues) -> TractResult<i64> {
119 Ok(self
120 .nodes
121 .iter()
122 .map(|it| it.mem_size.eval_to_i64(symbols))
123 .collect::<TractResult<Vec<_>>>()?
124 .into_iter()
125 .max()
126 .unwrap_or(0))
127 }
128
129 pub fn size(&self) -> TDim {
130 TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()).simplify()
131 }
132
133 pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
134 self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
135 }
136
137 pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
138 self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
139 }
140}
141
/// Memory schema with every symbolic size evaluated to a concrete number.
///
/// Built by `DeviceMemSchema::resolve`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeviceResolvedMemSchema {
    /// Offset assigned to each node, indexed by node id; `None` for nodes
    /// that have no slot in the schema.
    pub offsets_by_node: Vec<Option<usize>>,
    /// Total size of the memory covered by the schema.
    pub memory_size: usize,
}
149
150#[derive(Debug, Clone, PartialEq, Eq, Hash)]
153pub struct DeviceMemSchema {
154 pub model_num_nodes: usize,
156 pub by_partition: Vec<Partition>,
157 pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
159}
160
161impl DeviceMemSchema {
162 pub fn size_by_partition(&self) -> Vec<TDim> {
164 self.by_partition.iter().map(|it| it.size()).collect()
165 }
166
167 pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
169 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
170 }
171
172 pub fn memory_size(&self) -> TDim {
174 self.by_partition.iter().map(|it| it.size()).sum()
175 }
176
177 pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
179 self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
180 }
181
182 pub fn compute_offset_by_node(
185 &self,
186 symbols: &SymbolValues,
187 ) -> TractResult<Vec<Option<usize>>> {
188 let mut cursor = 0;
189 let mut offset_by_node = vec![None; self.model_num_nodes];
190
191 for partition in self.by_partition.iter() {
192 for node_mem in partition.nodes.iter() {
193 offset_by_node[node_mem.node] = Some(cursor);
194 }
195 cursor += partition.eval_size_to_i64(symbols)? as usize;
196 }
197
198 Ok(offset_by_node)
199 }
200
201 pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
205 Ok(self
206 .by_steps
207 .iter()
208 .map(|active_nodes| {
209 active_nodes
210 .iter()
211 .flatten()
212 .map(|it| it.mem_size.clone())
213 .sum::<TDim>()
214 .eval_to_i64(symbols)
215 })
216 .collect::<TractResult<Vec<_>>>()?
217 .into_iter()
218 .max()
219 .unwrap_or(0))
220 }
221
222 pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
226 let memory_size = self.eval_memory_size(symbols)? as f32;
227 let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
228 Ok(peak_memory_size / memory_size)
229 }
230}
231
232impl fmt::Display for DeviceMemSchema {
233 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
234 for (step, mem_step) in self.by_steps.iter().enumerate() {
235 writeln!(
236 fmt,
237 "step: {:5} => |{}|",
238 step,
239 mem_step
240 .iter()
241 .map(|n| -> String {
242 n.as_ref()
243 .map(|it| format!("{:^7}", it.node))
244 .unwrap_or(format!("{:^7}", "*"))
245 })
246 .collect::<Vec<String>>()
247 .join("|")
248 )?;
249 }
250 writeln!(fmt, "memory_size: {}", self.memory_size())?;
251 Ok(())
252 }
253}
254
255impl DeviceMemSchema {
256 pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<DeviceResolvedMemSchema> {
258 Ok(DeviceResolvedMemSchema {
259 offsets_by_node: self.compute_offset_by_node(symbols)?,
260 memory_size: self.eval_memory_size(symbols)?.try_into()?,
261 })
262 }
263
264 pub fn build(
268 model: &TypedModel,
269 order: &[usize],
270 hint: &SymbolValues,
271 ) -> TractResult<DeviceMemSchema> {
272 let mut nodes_mem_req = eval_device_mem_req_for_nodes(model, order)?;
273
274 let hinted_mem_size = nodes_mem_req
275 .iter()
276 .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?)))
277 .collect::<TractResult<HashMap<usize, i64>>>()?;
278
279 nodes_mem_req.sort_by(|lhs, rhs| {
280 let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node);
281 let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node);
282 lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()
283 });
284
285 let mut partitions: Vec<Partition> = vec![];
286 for node_mem in nodes_mem_req {
287 let mut available = partitions
289 .iter_mut()
290 .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
291 .collect::<Vec<_>>();
292
293 available.sort_by_cached_key(|n| {
294 -n.nodes.iter().flat_map(|it| hinted_mem_size.get(&it.node)).sum::<i64>()
295 });
296
297 match available.first_mut() {
298 Some(available) => {
299 available.nodes.push(node_mem);
300 }
301 None => partitions.push(Partition { nodes: vec![node_mem] }),
302 }
303 }
304
305 let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
306 .map(|step| {
307 let mem_step: Vec<_> =
308 partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
309 ensure!(mem_step.len() <= partitions.len());
310 Ok(mem_step)
311 })
312 .collect::<TractResult<Vec<_>>>()?;
313
314 Ok(DeviceMemSchema {
315 model_num_nodes: model.nodes().len(),
316 by_partition: partitions,
317 by_steps,
318 })
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a lifetime covering `[start, end)`.
    fn span(start: usize, end: usize) -> Lifetime {
        Lifetime { start, end }
    }

    /// Shorthand for a node memory requirement.
    fn mem_req(node: usize, lifetime: Lifetime, mem_size: TDim) -> NodeMemReq {
        NodeMemReq { node, lifetime, mem_size }
    }

    #[test]
    fn test_lifetime_is_disjoint() {
        let first = span(0, 5);
        let second = span(5, 10);
        let straddling = span(3, 7);

        // Touching ranges do not overlap, in either direction.
        assert!(first.is_disjoint(&second));
        assert!(second.is_disjoint(&first));
        assert!(!first.is_disjoint(&straddling));
        assert!(!straddling.is_disjoint(&second));
    }

    #[test]
    fn test_lifetime_is_alive_at_step() {
        let lifetime = span(5, 10);

        assert!(!lifetime.is_alive_at_step(4));
        assert!(lifetime.is_alive_at_step(5));
        assert!(lifetime.is_alive_at_step(7));
        assert!(lifetime.is_alive_at_step(9));
        // The end bound is exclusive.
        assert!(!lifetime.is_alive_at_step(10));
    }

    #[test]
    fn test_empty_lifetime() {
        let lifetime = span(5, 5);
        assert!(lifetime.is_empty());
        assert_eq!(lifetime.len(), 0);
    }

    #[test]
    fn test_node_mem_req_basic() {
        let req = mem_req(1, span(0, 5), 1000.into());

        assert_eq!(req.node, 1);
        assert_eq!(req.lifetime.start, 0);
        assert_eq!(req.lifetime.end, 5);
        assert_eq!(req.mem_size.to_i64().unwrap(), 1000);
    }

    #[test]
    fn test_partition_has_no_conflict() {
        let partition = Partition { nodes: vec![mem_req(1, span(0, 5), 1000.into())] };

        assert!(partition.has_no_conflict_with_lifetime(&span(5, 10)));
        assert!(!partition.has_no_conflict_with_lifetime(&span(3, 7)));
    }

    #[test]
    fn test_partition_find_node() {
        let early = mem_req(1, span(0, 5), 1000.into());
        let late = mem_req(2, span(5, 10), 2000.into());

        let partition = Partition { nodes: vec![early.clone(), late.clone()] };

        assert_eq!(partition.find_node_alive_at_step(3), Some(&early));
        assert_eq!(partition.find_node_alive_at_step(7), Some(&late));
        assert_eq!(partition.find_node_alive_at_step(10), None);
    }
}